From 9ec93f08623bf7ee7133c10170b6e6e3f552bc01 Mon Sep 17 00:00:00 2001
From: Benjamin Sugden
Date: Thu, 5 Sep 2024 15:01:24 +0200
Subject: [PATCH 01/16] PB-756: CRUD API for collection assets

Add collection asset endpoints for STAC v1. Duplicate the asset views and
serializers for collection assets.
---
 app/config/urls.py                     |   6 +
 app/stac_api/serializers.py            | 149 +++-
 app/stac_api/serializers_utils.py      |  10 +
 app/stac_api/signals.py                |  10 +
 app/stac_api/urls.py                   |  16 +-
 app/stac_api/views.py                  | 140 ++++
 app/stac_api/views_test.py             |   8 +
 app/tests/tests_10/base_test.py        |  50 ++
 app/tests/tests_10/data_factory.py     |  13 +-
 .../test_collection_assets_endpoint.py | 729 ++++++++++++++++++
 10 files changed, 1117 insertions(+), 14 deletions(-)
 create mode 100644 app/tests/tests_10/test_collection_assets_endpoint.py

diff --git a/app/config/urls.py b/app/config/urls.py
index a0eb4c8a..d28f7c8b 100644
--- a/app/config/urls.py
+++ b/app/config/urls.py
@@ -39,6 +39,7 @@ def checker(request):
     import debug_toolbar

     from stac_api.views_test import TestAssetUpsertHttp500
+    from stac_api.views_test import TestCollectionAssetUpsertHttp500
     from stac_api.views_test import TestCollectionUpsertHttp500
     from stac_api.views_test import TestHttp500
     from stac_api.views_test import TestItemUpsertHttp500
@@ -61,6 +62,11 @@ def checker(request):
         TestAssetUpsertHttp500.as_view(),
         name='test-asset-detail-http-500'
     ),
+    path(
+        'tests/test_collection_asset_upsert_http_500/<collection_name>/<asset_name>',
+        TestCollectionAssetUpsertHttp500.as_view(),
+        name='test-collection-asset-detail-http-500'
+    ),
     # Add v0.9 namespace to test routes.
     path(
         'tests/v0.9/test_asset_upsert_http_500/<collection_name>/<item_name>/<asset_name>',
diff --git a/app/stac_api/serializers.py b/app/stac_api/serializers.py
index 343f059e..c24c531e 100644
--- a/app/stac_api/serializers.py
+++ b/app/stac_api/serializers.py
@@ -15,6 +15,7 @@
 from stac_api.models import Asset
 from stac_api.models import AssetUpload
 from stac_api.models import Collection
+from stac_api.models import CollectionAsset
 from stac_api.models import CollectionLink
 from stac_api.models import Item
 from stac_api.models import ItemLink
@@ -380,6 +381,107 @@ def get_fields(self):
         return fields


+class CollectionAssetBaseSerializer(NonNullModelSerializer, UpsertModelSerializerMixin):
+    '''Collection asset serializer base class
+    '''
+
+    class Meta:
+        model = CollectionAsset
+        fields = [
+            'id',
+            'title',
+            'type',
+            'href',
+            'description',
+            'roles',
+            'proj_epsg',
+            'checksum_multihash',
+            'created',
+            'updated',
+        ]
+        validators = []  # Remove a default "unique together" constraint.
+        # (see:
+        # https://www.django-rest-framework.org/api-guide/validators/#limitations-of-validators)
+
+    # NOTE: when explicitly declaring fields, we need to add the same validation as for the
+    # model field!
+    id = serializers.CharField(source='name', max_length=255, validators=[validate_asset_name])
+    title = serializers.CharField(
+        required=False, max_length=255, allow_null=True, allow_blank=False
+    )
+    description = serializers.CharField(required=False, allow_blank=False, allow_null=True)
+    # Can't be a ChoiceField, as the validate method normalizes the MIME string only after it
+    # is read. Consistency is nevertheless guaranteed by the validate() and validate_type() methods.
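+    # The payload field is named "type" but maps to the model's "media_type" field via the
+    # source argument below.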
+    type = serializers.CharField(
+        source='media_type', required=True, allow_null=False, allow_blank=False
+    )
+    # Here we need to explicitly define these fields with the source, because they are renamed
+    # in the get_fields() method
+    proj_epsg = serializers.IntegerField(source='proj_epsg', allow_null=True, required=False)
+    # read only fields
+    checksum_multihash = serializers.CharField(source='checksum_multihash', read_only=True)
+    href = HrefField(source='file', read_only=True)
+    created = serializers.DateTimeField(read_only=True)
+    updated = serializers.DateTimeField(read_only=True)
+
+    # helper variable to provide the collection for upsert validation
+    # see views.CollectionAssetDetail.perform_upsert
+    collection = None
+
+    def create(self, validated_data):
+        asset = validate_uniqueness_and_create(Asset, validated_data)
+        return asset
+
+    def update_or_create(self, look_up, validated_data):
+        """
+        Update or create the asset object selected by kwargs and return the instance.
+        When no asset object matches the kwargs selection, a new asset is created.
+        Args:
+            validated_data: dict
+                Copy of the validated_data to use for the update
+            kwargs: dict
+                Object selection arguments (NOTE: the selection arguments must match a unique
+                object in DB otherwise an IntegrityError will be raised)
+        Returns: tuple
+            CollectionAsset instance and True if created, otherwise False
+        """
+        asset, created = CollectionAsset.objects.update_or_create(
+            **look_up,
+            defaults=validated_data,
+        )
+        return asset, created
+
+    def validate_type(self, value):
+        ''' Validates the field "type"
+        '''
+        return normalize_and_validate_media_type(value)
+
+    def validate(self, attrs):
+        name = attrs['name'] if not self.partial else attrs.get('name', self.instance.name)
+        media_type = attrs['media_type'] if not self.partial else attrs.get(
+            'media_type', self.instance.media_type
+        )
+        validate_asset_name_with_media_type(name, media_type)
+
+        validate_json_payload(self)
+
+        return attrs
+
+    def get_fields(self):
+        fields = super().get_fields()
+        # This is a hack to allow fields with special characters
+        fields['proj:epsg'] = fields.pop('proj_epsg')
+        fields['file:checksum'] = fields.pop('checksum_multihash')
+
+        # Older versions of the api still use a different name
+        request = self.context.get('request')
+        if not is_api_version_1(request):
+            fields['checksum:multihash'] = fields.pop('file:checksum')
+            fields.pop('roles', None)
+
+        return fields
+
+
 class AssetSerializer(AssetBaseSerializer):
     '''Asset serializer for the asset views

@@ -435,6 +537,27 @@ def validate(self, attrs):
         return super().validate(attrs)


+class CollectionAssetSerializer(CollectionAssetBaseSerializer):
+    '''Collection Asset serializer for the collection asset views
+
+    This serializer adds the links list attribute.
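+    The links are generated automatically in to_representation() below (self, root and
+    parent relations, see get_relation_links()).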
+ ''' + + def to_representation(self, instance): + collection = instance.collection.name + name = instance.name + request = self.context.get("request") + representation = super().to_representation(instance) + # Add auto links + # We use OrderedDict, although it is not necessary, because the default serializer/model for + # links already uses OrderedDict, this way we keep consistency between auto link and user + # link + representation['links'] = get_relation_links( + request, 'collection-asset-detail', [collection, name] + ) + return representation + + class AssetsForItemSerializer(AssetBaseSerializer): '''Assets serializer for nesting them inside the item @@ -462,6 +585,30 @@ class Meta: ] +class CollectionAssetsForCollectionSerializer(CollectionAssetBaseSerializer): + '''Collection assets serializer for nesting them inside the collection + + Assets should be nested inside their collection but using a dictionary instead of a list and + without links. + ''' + + class Meta: + model = CollectionAsset + list_serializer_class = AssetsDictSerializer + fields = [ + 'id', + 'title', + 'type', + 'href', + 'description', + 'roles', + 'proj_epsg', + 'checksum_multihash', + 'created', + 'updated' + ] + + class CollectionSerializer(NonNullModelSerializer, UpsertModelSerializerMixin): class Meta: @@ -511,7 +658,7 @@ class Meta: stac_extensions = serializers.SerializerMethodField(read_only=True) stac_version = serializers.SerializerMethodField(read_only=True) itemType = serializers.ReadOnlyField(default="Feature") # pylint: disable=invalid-name - assets = AssetsForItemSerializer(many=True, read_only=True) + assets = CollectionAssetsForCollectionSerializer(many=True, read_only=True) def get_crs(self, obj): return ["http://www.opengis.net/def/crs/OGC/1.3/CRS84"] diff --git a/app/stac_api/serializers_utils.py b/app/stac_api/serializers_utils.py index 03f950ae..68b0a2b6 100644 --- a/app/stac_api/serializers_utils.py +++ b/app/stac_api/serializers_utils.py @@ -72,6 +72,14 @@ def update_or_create_links(model, instance, instance_type, links_data): 'parent': 'landing-page', 'browser': 'browser-collection', }, + 'collection-assets-list': { + 'parent': 'collection-detail', + 'browser': None, + }, + 'collection-asset-detail': { + 'parent': 'collection-detail', + 'browser': None, + }, 'items-list': { 'parent': 'collection-detail', 'browser': 'browser-collection', @@ -107,6 +115,8 @@ def get_parent_link(request, view, view_args=()): ''' def parent_args(view, args): + if view.startswith('collection-asset'): + return args[:1] if view.startswith('item'): return args[:1] if view.startswith('asset'): diff --git a/app/stac_api/signals.py b/app/stac_api/signals.py index 8c8063b5..2005869a 100644 --- a/app/stac_api/signals.py +++ b/app/stac_api/signals.py @@ -6,6 +6,7 @@ from stac_api.models import Asset from stac_api.models import AssetUpload +from stac_api.models import CollectionAsset logger = logging.getLogger(__name__) @@ -36,3 +37,12 @@ def delete_s3_asset(sender, instance, **kwargs): # hence it has to be done here. logger.info("The asset %s is deleted from s3", instance.file.name) instance.file.delete(save=False) + + +@receiver(pre_delete, sender=CollectionAsset) +def delete_s3_collection_asset(sender, instance, **kwargs): + # The file is not automatically deleted by Django + # when the object holding its reference is deleted + # hence it has to be done here. 
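+    # file.delete(save=False) removes the S3 object without re-saving the model instance,
+    # which is being deleted anyway.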
+    logger.info("The collection asset %s is deleted from s3", instance.file.name)
+    instance.file.delete(save=False)
diff --git a/app/stac_api/urls.py b/app/stac_api/urls.py
index bd23ec38..9e3328e4 100644
--- a/app/stac_api/urls.py
+++ b/app/stac_api/urls.py
@@ -11,6 +11,8 @@
 from stac_api.views import AssetUploadDetail
 from stac_api.views import AssetUploadPartsList
 from stac_api.views import AssetUploadsList
+from stac_api.views import CollectionAssetDetail
+from stac_api.views import CollectionAssetsList
 from stac_api.views import CollectionDetail
 from stac_api.views import CollectionList
 from stac_api.views import ConformancePageDetail
@@ -41,7 +43,19 @@
     path("/assets/<asset_name>", include(asset_urls))
 ]

+collection_asset_urls = [
+    path("", CollectionAssetDetail.as_view(), name='collection-asset-detail'),
+]
+
 collection_urls = [
+    path("", CollectionDetail.as_view(), name='collection-detail'),
+    path("/items", ItemsList.as_view(), name='items-list'),
+    path("/items/<item_name>", include(item_urls)),
+    path("/assets", CollectionAssetsList.as_view(), name='collection-assets-list'),
+    path("/assets/<asset_name>", include(collection_asset_urls))
+]
+
+collection_urls_v09 = [
     path("", CollectionDetail.as_view(), name='collection-detail'),
     path("/items", ItemsList.as_view(), name='items-list'),
     path("/items/<item_name>", include(item_urls))
 ]
@@ -58,7 +72,7 @@
     path("conformance", ConformancePageDetail.as_view(), name='conformance'),
     path("search", SearchList.as_view(), name='search-list'),
     path("collections", CollectionList.as_view(), name='collections-list'),
-    path("collections/<collection_name>", include(collection_urls)),
+    path("collections/<collection_name>", include(collection_urls_v09)),
     path("update-extent", recalculate_extent)
 ],
 "v0.9"),
diff --git a/app/stac_api/views.py b/app/stac_api/views.py
index 2d75d099..72b016df 100644
--- a/app/stac_api/views.py
+++ b/app/stac_api/views.py
@@ -30,6 +30,7 @@
 from stac_api.models import Asset
 from stac_api.models import AssetUpload
 from stac_api.models import Collection
+from stac_api.models import CollectionAsset
 from stac_api.models import Item
 from stac_api.models import LandingPage
 from stac_api.pagination import ExtApiPagination
@@ -38,6 +39,7 @@
 from stac_api.serializers import AssetSerializer
 from stac_api.serializers import AssetUploadPartsSerializer
 from stac_api.serializers import AssetUploadSerializer
+from stac_api.serializers import CollectionAssetSerializer
 from stac_api.serializers import CollectionSerializer
 from stac_api.serializers import ConformancePageSerializer
 from stac_api.serializers import ItemSerializer
@@ -45,6 +47,7 @@
 from stac_api.serializers_utils import get_relation_links
 from stac_api.utils import call_calculate_extent
 from stac_api.utils import get_asset_path
+from stac_api.utils import get_collection_asset_path
 from stac_api.utils import harmonize_post_get_for_search
 from stac_api.utils import is_api_version_1
 from stac_api.utils import select_s3_bucket
@@ -111,6 +114,31 @@ def get_item_etag(request, *args, **kwargs):
     return tag


+def get_collection_asset_etag(request, *args, **kwargs):
+    '''Get the ETag for a collection asset object
+
+    The ETag is a UUID4 recomputed on each object change
+    '''
+    tag = get_etag(
+        CollectionAsset.objects.filter(
+            collection__name=kwargs['collection_name'], name=kwargs['asset_name']
+        )
+    )
+
+    if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE:
+        logger.debug(
+            "Output of EXPLAIN.. 
ANALYZE from get_collection_asset_etag():\n%s", + CollectionAsset.objects.filter(name=kwargs['asset_name'] + ).explain(verbose=True, analyze=True) + ) + logger.debug( + "The corresponding SQL statement:\n%s", + CollectionAsset.objects.filter(name=kwargs['asset_name']).query + ) + + return tag + + def get_asset_etag(request, *args, **kwargs): '''Get the ETag for a asset object @@ -527,6 +555,34 @@ def get(self, request, *args, **kwargs): return response +class CollectionAssetsList(generics.GenericAPIView): + name = 'collection-assets-list' # this name must match the name in urls.py + serializer_class = CollectionAssetSerializer + pagination_class = None + + def get_queryset(self): + # filter based on the url + return CollectionAsset.objects.filter(collection__name=self.kwargs['collection_name'] + ).order_by('name') + + def get(self, request, *args, **kwargs): + validate_collection(self.kwargs) + + queryset = self.filter_queryset(self.get_queryset()) + update_interval = Collection.objects.values('update_interval').get( + name=self.kwargs['collection_name'] + )['update_interval'] + serializer = self.get_serializer(queryset, many=True) + + data = { + 'assets': serializer.data, + 'links': get_relation_links(request, self.name, [self.kwargs['collection_name']]) + } + response = Response(data) + views_mixins.patch_cache_settings_by_update_interval(response, update_interval) + return response + + class AssetDetail( generics.GenericAPIView, views_mixins.UpdateInsertModelMixin, @@ -640,6 +696,90 @@ def delete(self, request, *args, **kwargs): return self.destroy(request, *args, **kwargs) +class CollectionAssetDetail( + generics.GenericAPIView, + views_mixins.UpdateInsertModelMixin, + views_mixins.DestroyModelMixin, + views_mixins.RetrieveModelDynCacheMixin +): + # this name must match the name in urls.py and is used by the DestroyModelMixin + name = 'collection-asset-detail' + serializer_class = CollectionAssetSerializer + lookup_url_kwarg = "asset_name" + lookup_field = "name" + + def get_queryset(self): + # filter based on the url + return CollectionAsset.objects.filter(collection__name=self.kwargs['collection_name']) + + def get_serializer(self, *args, **kwargs): + serializer_class = self.get_serializer_class() + kwargs.setdefault('context', self.get_serializer_context()) + collection = get_object_or_404(Collection, name=self.kwargs['collection_name']) + serializer = serializer_class(*args, **kwargs) + + # for the validation the serializer needs to know the collection of the + # asset. In case of inserting, the asset doesn't exist and thus the collection + # can't be read from the instance, which is why we pass the collection manually + # here. 
See serializers.AssetBaseSerializer._validate_href_field + serializer.collection = collection + return serializer + + def perform_update(self, serializer): + collection = get_object_or_404(Collection, name=self.kwargs['collection_name']) + validate_renaming( + serializer, + original_id=self.kwargs['asset_name'], + extra_log={ + 'request': self.request._request, # pylint: disable=protected-access + 'collection': self.kwargs['collection_name'], + 'asset': self.kwargs['asset_name'] + } + ) + return serializer.save( + collection=collection, + file=get_collection_asset_path(collection, self.kwargs['asset_name']) + ) + + def perform_upsert(self, serializer, lookup): + collection = get_object_or_404(Collection, name=self.kwargs['collection_name']) + validate_renaming( + serializer, + original_id=self.kwargs['asset_name'], + extra_log={ + 'request': self.request._request, # pylint: disable=protected-access + 'collection': self.kwargs['collection_name'], + 'asset': self.kwargs['asset_name'] + } + ) + lookup['collection__name'] = collection.name + + return serializer.upsert( + lookup, + collection=collection, + file=get_collection_asset_path(collection, self.kwargs['asset_name']) + ) + + @etag(get_collection_asset_etag) + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_collection_asset_etag) + def put(self, request, *args, **kwargs): + return self.upsert(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_collection_asset_etag) + def patch(self, request, *args, **kwargs): + return self.partial_update(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_collection_asset_etag) + def delete(self, request, *args, **kwargs): + return self.destroy(request, *args, **kwargs) + + class AssetUploadBase(generics.GenericAPIView): serializer_class = AssetUploadSerializer lookup_url_kwarg = "upload_id" diff --git a/app/stac_api/views_test.py b/app/stac_api/views_test.py index a23f9e3e..d1beefd2 100644 --- a/app/stac_api/views_test.py +++ b/app/stac_api/views_test.py @@ -4,6 +4,7 @@ from stac_api.models import LandingPage from stac_api.views import AssetDetail +from stac_api.views import CollectionAssetDetail from stac_api.views import CollectionDetail from stac_api.views import ItemDetail @@ -39,3 +40,10 @@ class TestAssetUpsertHttp500(AssetDetail): def perform_upsert(self, serializer, lookup): super().perform_upsert(serializer, lookup) raise AttributeError('test exception') + + +class TestCollectionAssetUpsertHttp500(CollectionAssetDetail): + + def perform_upsert(self, serializer, lookup): + super().perform_upsert(serializer, lookup) + raise AttributeError('test exception') diff --git a/app/tests/tests_10/base_test.py b/app/tests/tests_10/base_test.py index bbfc86ba..cf02f2f8 100644 --- a/app/tests/tests_10/base_test.py +++ b/app/tests/tests_10/base_test.py @@ -302,6 +302,56 @@ def check_stac_asset(self, expected, current, collection, item, ignore=None): ] self._check_stac_links('asset.links', links, current['links']) + def check_stac_collection_asset(self, expected, current, collection, ignore=None): + '''Check a STAC Collection Asset data + + Check if the `current` collection asset data match the `expected`. 
This check is a subset
+        check, which means that if a value is missing from `current`, a test assertion is
+        raised, while if a value is in `current` but not in `expected`, the test passes. The
+        function also knows the STAC Spec and does some checks based on it.
+
+        Args:
+            expected: dict
+                Expected STAC Asset
+            current: dict
+                Current STAC Asset to test
+            ignore: list(string) | None
+                List of keys to ignore in the test
+        '''
+        if ignore is None:
+            ignore = []
+        self._check_stac_dictsubset('asset', expected, current, ignore=ignore)
+
+        # check required fields
+        for key in ['links', 'id', 'type', 'href']:
+            if key in ignore:
+                logger.info('Ignoring key %s in asset', key)
+                continue
+            self.assertIn(key, current, msg=f'Asset {key} is missing')
+        for date_field in ['created', 'updated']:
+            if date_field in ignore:
+                logger.info('Ignoring key %s in asset', date_field)
+                continue
+            self.assertIn(date_field, current, msg=f'Asset {date_field} is missing')
+            self.assertTrue(
+                fromisoformat(current[date_field]),
+                msg=f"The asset field {date_field} has an invalid date"
+            )
+        if 'links' not in ignore:
+            name = current['id']
+            links = [
+                {
+                    'rel': 'self',
+                    'href': f'{TEST_LINK_ROOT_HREF}/collections/{collection}/assets/{name}'
+                },
+                TEST_LINK_ROOT,
+                {
+                    'rel': 'parent',
+                    'href': f'{TEST_LINK_ROOT_HREF}/collections/{collection}',
+                },
+            ]
+            self._check_stac_links('asset.links', links, current['links'])
+
     def _check_stac_dictsubset(self, parent_path, expected, current, ignore=None):
         for key, value in expected.items():
             path = f'{parent_path}.{key}'
diff --git a/app/tests/tests_10/data_factory.py b/app/tests/tests_10/data_factory.py
index 3a463fc1..2eebaa92 100644
--- a/app/tests/tests_10/data_factory.py
+++ b/app/tests/tests_10/data_factory.py
@@ -801,23 +801,12 @@ class CollectionAssetSample(SampleData):
     samples_dict = collection_asset_samples
     key_mapping = {
         'name': 'id',
-        'eo_gsd': 'gsd',
-        'geoadmin_variant': 'geoadmin:variant',
-        'geoadmin_lang': 'geoadmin:lang',
         'proj_epsg': 'proj:epsg',
         'media_type': 'type',
         'checksum_multihash': 'file:checksum',
         'file': 'href'
     }
-    optional_fields = [
-        'title',
-        'description',
-        'eo_gsd',
-        'geoadmin_variant',
-        'geoadmin_lang',
-        'proj_epsg',
-        'checksum_multihash'
-    ]
+    optional_fields = ['title', 'description', 'proj_epsg', 'checksum_multihash']
     read_only_fields = ['created', 'updated', 'href', 'file:checksum']

     def __init__(self, collection, sample='asset-1', name=None, required_only=False, **kwargs):
diff --git a/app/tests/tests_10/test_collection_assets_endpoint.py b/app/tests/tests_10/test_collection_assets_endpoint.py
new file mode 100644
index 00000000..1596a8d9
--- /dev/null
+++ b/app/tests/tests_10/test_collection_assets_endpoint.py
@@ -0,0 +1,729 @@
+import logging
+from datetime import datetime
+from json import dumps
+from json import loads
+from pprint import pformat
+
+from django.contrib.auth import get_user_model
+from django.test import Client
+from django.urls import reverse
+
+from stac_api.models import CollectionAsset
+from stac_api.utils import get_collection_asset_path
+from stac_api.utils import utc_aware
+
+from tests.tests_10.base_test import STAC_BASE_V
+from tests.tests_10.base_test import StacBaseTestCase
+from tests.tests_10.base_test import StacBaseTransactionTestCase
+from tests.tests_10.data_factory import Factory
+from tests.tests_10.utils import reverse_version
+from tests.utils import S3TestMixin
+from tests.utils import client_login
+from tests.utils import disableLogger
+from tests.utils import mock_s3_asset_file
+
+logger = logging.getLogger(__name__)
+
+
+def to_dict(input_ordered_dict):
+    return loads(dumps(input_ordered_dict))
+
+
+class CollectionAssetsEndpointTestCase(StacBaseTestCase):
+
+    @mock_s3_asset_file
+    def setUp(self):  # pylint: disable=invalid-name
+        self.client = Client()
+        self.factory = Factory()
+        self.collection = self.factory.create_collection_sample().model
+        self.asset_1 = self.factory.create_collection_asset_sample(
+            collection=self.collection, name="asset-1.tiff", db_create=True
+        )
+        self.maxDiff = None  # pylint: disable=invalid-name
+
+    def test_assets_endpoint(self):
+        collection_name = self.collection.name
+        # To test the asset ordering, make sure not to create them in ascending order
+        asset_2 = self.factory.create_collection_asset_sample(
+            collection=self.collection, sample='asset-2', name="asset-2.txt", db_create=True
+        )
+        asset_3 = self.factory.create_collection_asset_sample(
+            collection=self.collection, name="asset-0.pdf", sample='asset-3', db_create=True
+        )
+        response = self.client.get(f"/{STAC_BASE_V}/collections/{collection_name}/assets")
+        self.assertStatusCode(200, response)
+        json_data = response.json()
+        logger.debug('Response (%s):\n%s', type(json_data), pformat(json_data))
+
+        self.assertIn('assets', json_data, msg='assets is missing in response')
+        self.assertEqual(
+            3, len(json_data['assets']), msg='Number of assets doesn\'t match the expected'
+        )
+
+        # Check that the output is sorted by name
+        asset_ids = [asset['id'] for asset in json_data['assets']]
+        self.assertListEqual(asset_ids, sorted(asset_ids), msg="Assets are not sorted by ID")
+
+        asset_samples = sorted([self.asset_1, asset_2, asset_3], key=lambda asset: asset['name'])
+        for i, asset in enumerate(asset_samples):
+            self.check_stac_collection_asset(asset.json, json_data['assets'][i], collection_name)
+
+    def test_assets_endpoint_collection_does_not_exist(self):
+        collection_name = "non-existent"
+        response = self.client.get(f"/{STAC_BASE_V}/collections/{collection_name}/assets")
+        self.assertStatusCode(404, response)
+
+    def test_single_asset_endpoint(self):
+        collection_name = self.collection.name
+        asset_name = self.asset_1["name"]
+        response = self.client.get(
+            f"/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}"
+        )
+        json_data = response.json()
+        self.assertStatusCode(200, response)
+        logger.debug('Response (%s):\n%s', type(json_data), pformat(json_data))
+
+        self.check_stac_collection_asset(self.asset_1.json, json_data, collection_name)
+
+        # The ETag changes between each test call due to the created and updated times that
+        # are part of the ETag hash computation
+        self.assertEtagHeader(None, response)
+
+
+class CollectionAssetsUnimplementedEndpointTestCase(StacBaseTestCase):
+
+    @mock_s3_asset_file
+    def setUp(self):  # pylint: disable=invalid-name
+        self.factory = Factory()
+        self.collection = self.factory.create_collection_sample().model
+        self.client = Client()
+        client_login(self.client)
+        self.maxDiff = None  # pylint: disable=invalid-name
+
+    def test_asset_unimplemented_post(self):
+        collection_name = self.collection.name
+        asset = self.factory.create_collection_asset_sample(
+            collection=self.collection, required_only=True
+        )
+        response = self.client.post(
+            f'/{STAC_BASE_V}/collections/{collection_name}/assets',
+            data=asset.get_json('post'),
+            content_type="application/json"
+        )
+        self.assertStatusCode(405, response)
+
+
+class CollectionAssetsCreateEndpointTestCase(StacBaseTestCase):
+
+    @mock_s3_asset_file
+    def setUp(self):  # pylint: disable=invalid-name
+        self.factory = Factory()
+        self.collection = self.factory.create_collection_sample().model
+        self.client = Client()
+        client_login(self.client)
+        self.maxDiff = None  # pylint: disable=invalid-name
+
+    def test_asset_upsert_create_only_required(self):
+        collection_name = self.collection.name
+        asset = self.factory.create_collection_asset_sample(
+            collection=self.collection, required_only=True
+        )
+        path = \
+            f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset["name"]}'
+        json_to_send = asset.get_json('put')
+        # Send a non-normalized form of the type to see if it is also accepted
+        json_to_send['type'] = 'image/TIFF;application=geotiff; Profile=cloud-optimized'
+        response = self.client.put(path, data=json_to_send, content_type="application/json")
+        json_data = response.json()
+        self.assertStatusCode(201, response)
+        self.assertLocationHeader(f"{path}", response)
+        self.check_stac_collection_asset(asset.json, json_data, collection_name)
+
+        # Check the data by reading it back
+        response = self.client.get(response['Location'])
+        json_data = response.json()
+        self.assertStatusCode(200, response)
+        self.check_stac_collection_asset(asset.json, json_data, collection_name)
+
+        # make sure that the optional fields are not present
+        self.assertNotIn('proj:epsg', json_data)
+        self.assertNotIn('description', json_data)
+        self.assertNotIn('title', json_data)
+        self.assertNotIn('file:checksum', json_data)
+
+    def test_asset_upsert_create(self):
+        collection = self.collection
+        asset = self.factory.create_collection_asset_sample(
+            collection=self.collection, sample='asset-no-checksum', create_asset_file=False
+        )
+        asset_name = asset['name']
+
+        response = self.client.get(
+            reverse_version('collection-asset-detail', args=[collection.name, asset_name])
+        )
+        # Check that the asset does not exist already
+        self.assertStatusCode(404, response)
+
+        # Check also that the asset does not exist in the DB already
+        self.assertFalse(
+            CollectionAsset.objects.filter(name=asset_name).exists(),
+            msg="Collection asset already exists"
+        )
+
+        # Now use upsert to create the new asset
+        response = self.client.put(
+            reverse_version('collection-asset-detail', args=[collection.name, asset_name]),
+            data=asset.get_json('put'),
+            content_type="application/json"
+        )
+        json_data = response.json()
+        self.assertStatusCode(201, response)
+        self.assertLocationHeader(
+            reverse_version('collection-asset-detail', args=[collection.name, asset_name]),
+            response
+        )
+        self.check_stac_collection_asset(asset.json, json_data, collection.name)
+
+        # make sure that all optional fields are present
+        self.assertIn('proj:epsg', json_data)
+        self.assertIn('description', json_data)
+        self.assertIn('title', json_data)
+
+        # Checksum multihash is set by the AssetUpload later on
+        self.assertNotIn('file:checksum', json_data)
+
+        # Check the data by reading it back
+        response = self.client.get(response['Location'])
+        json_data = response.json()
+        self.assertStatusCode(200, response)
+        self.check_stac_collection_asset(asset.json, json_data, collection.name)
+
+    def test_asset_upsert_create_non_existing_parent_collection_in_path(self):
+        asset = self.factory.create_collection_asset_sample(
+            collection=self.collection, create_asset_file=False
+        )
+        asset_name = asset['name']
+
+        path = (f'/{STAC_BASE_V}/collections/non-existing-collection/assets/'
+                f'{asset_name}')
+
+        # Check that the asset does not exist already
+        response = self.client.get(path)
+        self.assertStatusCode(404, response)
+
+        # Check also that the asset does not exist in the DB already
+        self.assertFalse(
+            CollectionAsset.objects.filter(name=asset_name).exists(),
+            msg="Collection asset already exists"
+        )
+
+        # Now use upsert to create the new asset
+        response = self.client.put(
+            path, data=asset.get_json('post'), content_type="application/json"
+        )
+        self.assertStatusCode(404, response)
+
+    def test_asset_upsert_create_empty_string(self):
+        collection_name = self.collection.name
+        asset = self.factory.create_collection_asset_sample(
+            collection=self.collection, required_only=True, description='', title=''
+        )
+
+        path = \
+            f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset["name"]}'
+        response = self.client.put(
+            path, data=asset.get_json('put'), content_type="application/json"
+        )
+        self.assertStatusCode(400, response)
+        json_data = response.json()
+        for field in ['description', 'title']:
+            self.assertIn(field, json_data['description'], msg=f'Field {field} error missing')
+
+    def invalid_request_wrapper(self, sample_name, expected_error_messages, **extra_params):
+        collection_name = self.collection.name
+        asset = self.factory.create_collection_asset_sample(
+            collection=self.collection, sample=sample_name, **extra_params
+        )
+
+        path = \
+            f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset["name"]}'
+        response = self.client.put(
+            path, data=asset.get_json('put'), content_type="application/json"
+        )
+        self.assertStatusCode(400, response)
+        self.assertEqual(
+            expected_error_messages,
+            response.json()['description'],
+            msg='Unexpected error message',
+        )
+
+        # Make sure that the asset is not found in DB
+        self.assertFalse(
+            CollectionAsset.objects.filter(name=asset.json['id']).exists(),
+            msg="Invalid asset has been created in DB"
+        )
+
+    def test_asset_upsert_create_invalid_data(self):
+        self.invalid_request_wrapper(
+            'asset-invalid', {
+                'proj:epsg': ['A valid integer is required.'],
+                'type': ['Invalid media type "dummy"']
+            }
+        )
+
+    def test_asset_upsert_create_invalid_type(self):
+        media_type_str = "image/tiff; application=Geotiff; profile=cloud-optimized"
+        self.invalid_request_wrapper(
+            'asset-invalid-type', {'type': [f'Invalid media type "{media_type_str}"']}
+        )
+
+    def test_asset_upsert_create_type_extension_mismatch(self):
+        media_type_str = "application/gml+xml"
+        self.invalid_request_wrapper(
+            'asset-invalid-type',
+            {
+                'non_field_errors': [
+                    f"Invalid id extension '.tiff', id must match its media type {media_type_str}"
+                ]
+            },
+            media_type=media_type_str,
+            # must be overridden, else the extension will automatically match the type
+            name='asset-invalid-type.tiff'
+        )
+
+
+class CollectionAssetsUpdateEndpointAssetFileTestCase(StacBaseTestCase):
+
+    @mock_s3_asset_file
+    def setUp(self):  # pylint: disable=invalid-name
+        self.factory = Factory()
+        self.collection = self.factory.create_collection_sample(db_create=True)
+        self.asset = self.factory.create_collection_asset_sample(
+            collection=self.collection.model, db_create=True
+        )
+        self.client = Client()
+        client_login(self.client)
+        self.maxDiff = None  # pylint: disable=invalid-name
+
+    def test_asset_endpoint_patch_put_href(self):
+        collection_name = self.collection['name']
+        asset_name = self.asset['name']
+        asset_sample = self.asset.copy()
+
+        put_payload = asset_sample.get_json('put')
+        put_payload['href'] = 'https://testserver/non-existing-asset'
+        patch_payload = {'href': 'https://testserver/non-existing-asset'}
+
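+        # href is read-only (it is derived from the uploaded file), so both PATCH and PUT
+        # with an href in the payload must be rejected with a 400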
+ path = f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}' + response = self.client.patch(path, data=patch_payload, content_type="application/json") + self.assertStatusCode(400, response) + description = response.json()['description'] + self.assertIn('href', description, msg=f'Unexpected field error {description}') + self.assertEqual( + "Found read-only property in payload", + description['href'][0], + msg="Unexpected error message" + ) + + response = self.client.put(path, data=put_payload, content_type="application/json") + self.assertStatusCode(400, response) + description = response.json()['description'] + self.assertIn('href', description, msg=f'Unexpected field error {description}') + self.assertEqual( + "Found read-only property in payload", + description['href'][0], + msg="Unexpected error message" + ) + + +class CollectionAssetsUpdateEndpointTestCase(StacBaseTestCase): + + @mock_s3_asset_file + def setUp(self): # pylint: disable=invalid-name + self.factory = Factory() + self.collection = self.factory.create_collection_sample(db_create=True) + self.asset = self.factory.create_collection_asset_sample( + collection=self.collection.model, db_create=True + ) + self.client = Client() + client_login(self.client) + self.maxDiff = None # pylint: disable=invalid-name + + def test_asset_endpoint_put(self): + collection_name = self.collection['name'] + asset_name = self.asset['name'] + changed_asset = self.factory.create_collection_asset_sample( + collection=self.collection.model, + name=asset_name, + sample='asset-1-updated', + media_type=self.asset['media_type'], + checksum_multihash=self.asset['checksum_multihash'], + create_asset_file=False + ) + + path = f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}' + response = self.client.put( + path, data=changed_asset.get_json('put'), content_type="application/json" + ) + json_data = response.json() + self.assertStatusCode(200, response) + self.check_stac_collection_asset(changed_asset.json, json_data, collection_name) + + # Check the data by reading it back + response = self.client.get(path) + json_data = response.json() + self.assertStatusCode(200, response) + self.check_stac_collection_asset(changed_asset.json, json_data, collection_name) + + def test_asset_endpoint_put_extra_payload(self): + collection_name = self.collection['name'] + asset_name = self.asset['name'] + changed_asset = self.factory.create_collection_asset_sample( + collection=self.collection.model, + name=asset_name, + sample='asset-1-updated', + media_type=self.asset['media_type'], + checksum_multihash=self.asset['checksum_multihash'], + extra_attribute='not allowed', + create_asset_file=False + ) + + path = f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}' + response = self.client.put( + path, data=changed_asset.get_json('put'), content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual({'extra_attribute': ['Unexpected property in payload']}, + response.json()['description'], + msg='Unexpected error message') + + def test_asset_endpoint_put_read_only_in_payload(self): + collection_name = self.collection['name'] + asset_name = self.asset['name'] + changed_asset = self.factory.create_collection_asset_sample( + collection=self.collection.model, + name=asset_name, + sample='asset-1-updated', + media_type=self.asset['media_type'], + created=utc_aware(datetime.utcnow()), + create_asset_file=False, + checksum_multihash=self.asset['checksum_multihash'], + ) + + path = 
f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}' + response = self.client.put( + path, + data=changed_asset.get_json('put', keep_read_only=True), + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual({ + 'created': ['Found read-only property in payload'], + 'file:checksum': ['Found read-only property in payload'] + }, + response.json()['description'], + msg='Unexpected error message') + + def test_asset_endpoint_put_rename_asset(self): + # rename should not be allowed + collection_name = self.collection['name'] + asset_name = self.asset['name'] + new_asset_name = "new-asset-name.txt" + changed_asset = self.factory.create_collection_asset_sample( + collection=self.collection.model, + name=new_asset_name, + sample='asset-1-updated', + checksum_multihash=self.asset['checksum_multihash'] + ) + + path = f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}' + response = self.client.put( + path, data=changed_asset.get_json('put'), content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual({'id': 'Renaming is not allowed'}, + response.json()['description'], + msg='Unexpected error message') + + # Check the data by reading it back + response = self.client.get( + f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}' + ) + json_data = response.json() + self.assertStatusCode(200, response) + + self.assertEqual(asset_name, json_data['id']) + + # Check the data that no new entry exist + response = self.client.get( + f'/{STAC_BASE_V}/collections/{collection_name}/assets/{new_asset_name}' + ) + + # 404 - not found + self.assertStatusCode(404, response) + + def test_asset_endpoint_patch_rename_asset(self): + # rename should not be allowed + collection_name = self.collection['name'] + asset_name = self.asset['name'] + new_asset_name = "new-asset-name.txt" + changed_asset = self.factory.create_collection_asset_sample( + collection=self.collection.model, name=new_asset_name, sample='asset-1-updated' + ) + + path = f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}' + response = self.client.patch( + path, data=changed_asset.get_json('patch'), content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual({'id': 'Renaming is not allowed'}, + response.json()['description'], + msg='Unexpected error message') + + # Check the data by reading it back + response = self.client.get( + f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}' + ) + json_data = response.json() + self.assertStatusCode(200, response) + + self.assertEqual(asset_name, json_data['id']) + + # Check the data that no new entry exist + response = self.client.get( + f'/{STAC_BASE_V}/collections/{collection_name}/assets/{new_asset_name}' + ) + + # 404 - not found + self.assertStatusCode(404, response) + + def test_asset_endpoint_patch_extra_payload(self): + collection_name = self.collection['name'] + asset_name = self.asset['name'] + changed_asset = self.factory.create_collection_asset_sample( + collection=self.collection.model, + name=asset_name, + sample='asset-1-updated', + media_type=self.asset['media_type'], + extra_payload='invalid' + ) + + path = f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}' + response = self.client.patch( + path, data=changed_asset.get_json('patch'), content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual({'extra_payload': ['Unexpected property in payload']}, + 
response.json()['description'],
+                         msg='Unexpected error message')
+
+    def test_asset_endpoint_patch_read_only_in_payload(self):
+        collection_name = self.collection['name']
+        asset_name = self.asset['name']
+        changed_asset = self.factory.create_collection_asset_sample(
+            collection=self.collection.model,
+            name=asset_name,
+            sample='asset-1-updated',
+            media_type=self.asset['media_type'],
+            created=utc_aware(datetime.utcnow())
+        )
+
+        path = f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}'
+        response = self.client.patch(
+            path,
+            data=changed_asset.get_json('patch', keep_read_only=True),
+            content_type="application/json"
+        )
+        self.assertStatusCode(400, response)
+        self.assertEqual({'created': ['Found read-only property in payload']},
+                         response.json()['description'],
+                         msg='Unexpected error message')
+
+    def test_asset_atomic_upsert_create_500(self):
+        sample = self.factory.create_collection_asset_sample(
+            self.collection.model, create_asset_file=True
+        )
+
+        # the dataset to upsert does not exist yet
+        with self.settings(DEBUG_PROPAGATE_API_EXCEPTIONS=True), disableLogger('stac_api.apps'):
+            response = self.client.put(
+                reverse(
+                    'test-collection-asset-detail-http-500',
+                    args=[self.collection['name'], sample['name']]
+                ),
+                data=sample.get_json('put'),
+                content_type='application/json'
+            )
+        self.assertStatusCode(500, response)
+        self.assertEqual(response.json()['description'], "AttributeError('test exception')")
+
+        # Make sure that the resource has not been created
+        response = self.client.get(
+            reverse_version(
+                'collection-asset-detail', args=[self.collection['name'], sample['name']]
+            )
+        )
+        self.assertStatusCode(404, response)
+
+    def test_asset_atomic_upsert_update_500(self):
+        sample = self.factory.create_collection_asset_sample(
+            self.collection.model, name=self.asset['name'], create_asset_file=True
+        )
+
+        # Make sure the sample is different from the actual data
+        self.assertNotEqual(sample.attributes, self.asset.attributes)
+
+        # the dataset to update already exists
+        with self.settings(DEBUG_PROPAGATE_API_EXCEPTIONS=True), disableLogger('stac_api.apps'):
+            # because we explicitly test a crash here, we don't want to print a CRITICAL log
+            # on the console, therefore disable it.
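+            # The test view raises an AttributeError after perform_upsert(); the atomic
+            # request is expected to roll back, so the asset must stay unchanged.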
+            response = self.client.put(
+                reverse(
+                    'test-collection-asset-detail-http-500',
+                    args=[self.collection['name'], sample['name']]
+                ),
+                data=sample.get_json('put'),
+                content_type='application/json'
+            )
+        self.assertStatusCode(500, response)
+        self.assertEqual(response.json()['description'], "AttributeError('test exception')")
+
+        # Make sure that the resource has not been updated
+        response = self.client.get(
+            reverse_version(
+                'collection-asset-detail', args=[self.collection['name'], sample['name']]
+            )
+        )
+        self.assertStatusCode(200, response)
+        self.check_stac_collection_asset(
+            self.asset.json, response.json(), self.collection['name'], ignore=['item']
+        )
+
+
+class CollectionAssetRaceConditionTest(StacBaseTransactionTestCase):
+
+    def setUp(self):
+        self.username = 'user'
+        self.password = 'dummy-password'
+        get_user_model().objects.create_superuser(self.username, password=self.password)
+        self.factory = Factory()
+        self.collection_sample = self.factory.create_collection_sample(
+            sample='collection-2', db_create=True
+        )
+
+    def test_asset_upsert_race_condition(self):
+        workers = 5
+        status_201 = 0
+        asset_sample = self.factory.create_collection_asset_sample(
+            self.collection_sample.model,
+            sample='asset-no-checksum',
+        )
+
+        def asset_atomic_upsert_test(worker):
+            # This method runs on a separate thread, therefore it requires creating a new
+            # client and logging it in for each call.
+            client = Client()
+            client.login(username=self.username, password=self.password)
+            return client.put(
+                reverse_version(
+                    'collection-asset-detail',
+                    args=[self.collection_sample['name'], asset_sample['name']]
+                ),
+                data=asset_sample.get_json('put'),
+                content_type='application/json'
+            )
+
+        # We call the PUT asset several times in parallel with the same data to make sure
+        # that we don't have any race condition.
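+        # run_parallel (from StacBaseTransactionTestCase) runs the calls concurrently and
+        # returns the (worker, response) pairs together with any errors.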
+        responses, errors = self.run_parallel(workers, asset_atomic_upsert_test)
+
+        for worker, response in responses:
+            if response.status_code == 201:
+                status_201 += 1
+            self.assertIn(
+                response.status_code, [200, 201],
+                msg=f'Unexpected response status code {response.status_code} for worker {worker}'
+            )
+            self.check_stac_collection_asset(
+                asset_sample.json, response.json(), self.collection_sample['name'], ignore=['item']
+            )
+        self.assertEqual(status_201, 1, msg="Not exactly one upsert did a create!")
+
+
+class CollectionAssetsDeleteEndpointTestCase(StacBaseTestCase, S3TestMixin):
+
+    @mock_s3_asset_file
+    def setUp(self):  # pylint: disable=invalid-name
+        self.factory = Factory()
+        self.collection = self.factory.create_collection_sample().model
+        self.asset = self.factory.create_collection_asset_sample(collection=self.collection).model
+        self.client = Client()
+        client_login(self.client)
+        self.maxDiff = None  # pylint: disable=invalid-name
+
+    def test_asset_endpoint_delete_asset(self):
+        collection_name = self.collection.name
+        asset_name = self.asset.name
+        path = f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}'
+        s3_path = get_collection_asset_path(self.collection, asset_name)
+        self.assertS3ObjectExists(s3_path)
+        response = self.client.delete(path)
+        self.assertStatusCode(200, response)
+
+        # Check that it has really been deleted
+        self.assertS3ObjectNotExists(s3_path)
+        response = self.client.get(path)
+        self.assertStatusCode(404, response)
+
+        # Check that it is really not found in the DB anymore
+        self.assertFalse(
+            CollectionAsset.objects.filter(name=self.asset.name).exists(),
+            msg="Deleted asset still found in DB"
+        )
+
+    def test_asset_endpoint_delete_asset_invalid_name(self):
+        collection_name = self.collection.name
+        path = f"/{STAC_BASE_V}/collections/{collection_name}/assets/non-existent-asset"
+        response = self.client.delete(path)
+        self.assertStatusCode(404, response)
+
+
+class CollectionAssetsEndpointUnauthorizedTestCase(StacBaseTestCase):
+
+    @mock_s3_asset_file
+    def setUp(self):  # pylint: disable=invalid-name
+        self.factory = Factory()
+        self.collection = self.factory.create_collection_sample().model
+        self.asset = self.factory.create_collection_asset_sample(collection=self.collection).model
+        self.client = Client()
+
+    def test_unauthorized_asset_post_put_patch_delete(self):
+        collection_name = self.collection.name
+        asset_name = self.asset.name
+
+        new_asset = self.factory.create_collection_asset_sample(collection=self.collection).json
+        updated_asset = self.factory.create_collection_asset_sample(
+            collection=self.collection, name=asset_name, sample='asset-1-updated'
+        ).get_json('post')
+
+        # make sure POST fails for anonymous user:
+        path = f'/{STAC_BASE_V}/collections/{collection_name}/assets'
+        response = self.client.post(path, data=new_asset, content_type="application/json")
+        self.assertStatusCode(401, response, msg="Unauthorized post was permitted.")
+
+        # make sure PUT fails for anonymous user:
+        path = f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}'
+        response = self.client.put(path, data=updated_asset, content_type="application/json")
+        self.assertStatusCode(401, response, msg="Unauthorized put was permitted.")
+
+        # make sure PATCH fails for anonymous user:
+        path = f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}'
+        response = self.client.patch(path, data=updated_asset, content_type="application/json")
+        self.assertStatusCode(401, response, msg="Unauthorized patch was permitted.")
+
+        # make sure 
DELETE fails for anonymous user: + path = f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}' + response = self.client.delete(path) + self.assertStatusCode(401, response, msg="Unauthorized del was permitted.") From cd4977acbd978921b7ded60b3b99ef02a2cada11 Mon Sep 17 00:00:00 2001 From: Benjamin Sugden Date: Mon, 9 Sep 2024 15:58:04 +0200 Subject: [PATCH 02/16] PB-756: API to upload collection assets Add upload endpoints for collection assets. Mostly a duplication of the already existing multipart upload endpoints. --- .../0050_collectionassetupload_and_more.py | 137 ++ app/stac_api/models.py | 76 +- app/stac_api/s3_multipart_upload.py | 46 +- app/stac_api/serializers.py | 91 +- app/stac_api/signals.py | 19 + app/stac_api/urls.py | 32 + app/stac_api/validators_view.py | 25 + app/stac_api/views.py | 236 +++ .../test_collection_asset_upload_endpoint.py | 1266 +++++++++++++++++ .../test_collection_asset_upload_model.py | 212 +++ 10 files changed, 2103 insertions(+), 37 deletions(-) create mode 100644 app/stac_api/migrations/0050_collectionassetupload_and_more.py create mode 100644 app/tests/tests_10/test_collection_asset_upload_endpoint.py create mode 100644 app/tests/tests_10/test_collection_asset_upload_model.py diff --git a/app/stac_api/migrations/0050_collectionassetupload_and_more.py b/app/stac_api/migrations/0050_collectionassetupload_and_more.py new file mode 100644 index 00000000..c933b2db --- /dev/null +++ b/app/stac_api/migrations/0050_collectionassetupload_and_more.py @@ -0,0 +1,137 @@ +# Generated by Django 5.0.8 on 2024-09-09 10:59 + +import pgtrigger.compiler +import pgtrigger.migrations + +import django.core.serializers.json +import django.core.validators +import django.db.models.deletion +from django.db import migrations +from django.db import models + +import stac_api.models + + +class Migration(migrations.Migration): + + dependencies = [ + ('stac_api', '0049_item_properties_expires'), + ] + + operations = [ + migrations.CreateModel( + name='CollectionAssetUpload', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('upload_id', models.CharField(max_length=255)), + ( + 'status', + models.CharField( + choices=[(None, ''), ('in-progress', 'In Progress'), + ('completed', 'Completed'), ('aborted', 'Aborted')], + default='in-progress', + max_length=32 + ) + ), + ( + 'number_parts', + models.IntegerField( + validators=[ + django.core.validators.MinValueValidator(1), + django.core.validators.MaxValueValidator(100) + ] + ) + ), + ( + 'md5_parts', + models.JSONField( + editable=False, encoder=django.core.serializers.json.DjangoJSONEncoder + ) + ), + ( + 'urls', + models.JSONField( + blank=True, + default=list, + encoder=django.core.serializers.json.DjangoJSONEncoder + ) + ), + ('created', models.DateTimeField(auto_now_add=True)), + ('ended', models.DateTimeField(blank=True, default=None, null=True)), + ('checksum_multihash', models.CharField(max_length=255)), + ('etag', models.CharField(default=stac_api.models.compute_etag, max_length=56)), + ( + 'update_interval', + models.IntegerField( + default=-1, + help_text= + 'Interval in seconds in which the asset data is updated.-1 means that the data is not on a regular basis updated.This field can only be set via the API.', + validators=[django.core.validators.MinValueValidator(-1)] + ) + ), + ( + 'content_encoding', + models.CharField( + blank=True, + choices=[(None, ''), ('gzip', 'Gzip'), ('br', 'Br')], + default='', + max_length=32 + ) + ), + ( + 'asset', + models.ForeignKey( + 
on_delete=django.db.models.deletion.CASCADE, + related_name='+', + to='stac_api.collectionasset' + ) + ), + ], + ), + migrations.AddConstraint( + model_name='collectionassetupload', + constraint=models.UniqueConstraint( + fields=('asset', 'upload_id'), + name='unique_asset_upload_collection_asset_upload_id' + ), + ), + migrations.AddConstraint( + model_name='collectionassetupload', + constraint=models.UniqueConstraint( + condition=models.Q(('status', 'in-progress')), + fields=('asset', 'status'), + name='unique_asset_upload_in_progress' + ), + ), + pgtrigger.migrations.AddTrigger( + model_name='collectionassetupload', + trigger=pgtrigger.compiler.Trigger( + name='add_asset_upload_trigger', + sql=pgtrigger.compiler.UpsertTriggerSql( + func= + '\n -- update AssetUpload auto variable\n NEW.etag = public.gen_random_uuid();\n\n RETURN NEW;\n ', + hash='5f51ec3c72c4d9fbe6b81d2fd881dd5228dc80bf', + operation='INSERT', + pgid='pgtrigger_add_asset_upload_trigger_8330c', + table='stac_api_collectionassetupload', + when='BEFORE' + ) + ), + ), + pgtrigger.migrations.AddTrigger( + model_name='collectionassetupload', + trigger=pgtrigger.compiler.Trigger( + name='update_asset_upload_trigger', + sql=pgtrigger.compiler.UpsertTriggerSql( + condition='WHEN (OLD.* IS DISTINCT FROM NEW.*)', + func= + '\n -- update AssetUpload auto variable\n NEW.etag = public.gen_random_uuid();\n\n RETURN NEW;\n ', + hash='0a7f1aa8f8c0bb2c413a7ce626f75c8da5bf4b6d', + operation='UPDATE', + pgid='pgtrigger_update_asset_upload_trigger_8d012', + table='stac_api_collectionassetupload', + when='BEFORE' + ) + ), + ), + ] diff --git a/app/stac_api/models.py b/app/stac_api/models.py index e192e22c..88237b4a 100644 --- a/app/stac_api/models.py +++ b/app/stac_api/models.py @@ -710,19 +710,10 @@ def get_asset_path(self): return get_collection_asset_path(self.collection, self.name) -class AssetUpload(models.Model): +class BaseAssetUpload(models.Model): class Meta: - constraints = [ - models.UniqueConstraint(fields=['asset', 'upload_id'], name='unique_together'), - # Make sure that there is only one asset upload in progress per asset - models.UniqueConstraint( - fields=['asset', 'status'], - condition=Q(status='in-progress'), - name='unique_in_progress' - ) - ] - triggers = generates_asset_upload_triggers() + abstract = True class Status(models.TextChoices): # pylint: disable=invalid-name @@ -741,7 +732,6 @@ class ContentEncoding(models.TextChoices): # using BigIntegerField as primary_key to deal with the expected large number of assets. 
id = models.BigAutoField(primary_key=True) - asset = models.ForeignKey(Asset, related_name='+', on_delete=models.CASCADE) upload_id = models.CharField(max_length=255, blank=False, null=False) status = models.CharField( choices=Status.choices, max_length=32, default=Status.IN_PROGRESS, blank=False, null=False @@ -777,6 +767,23 @@ class ContentEncoding(models.TextChoices): # Custom Manager that preselects the collection objects = AssetUploadManager() + +class AssetUpload(BaseAssetUpload): + + class Meta: + constraints = [ + models.UniqueConstraint(fields=['asset', 'upload_id'], name='unique_together'), + # Make sure that there is only one asset upload in progress per asset + models.UniqueConstraint( + fields=['asset', 'status'], + condition=Q(status='in-progress'), + name='unique_in_progress' + ) + ] + triggers = generates_asset_upload_triggers() + + asset = models.ForeignKey(Asset, related_name='+', on_delete=models.CASCADE) + def update_asset_from_upload(self): '''Updating the asset's file:checksum and update_interval from the upload @@ -804,6 +811,51 @@ def update_asset_from_upload(self): self.asset.save() +class CollectionAssetUpload(BaseAssetUpload): + + class Meta: + constraints = [ + models.UniqueConstraint( + fields=['asset', 'upload_id'], + name='unique_asset_upload_collection_asset_upload_id' + ), + # Make sure that there is only one asset upload in progress per asset + models.UniqueConstraint( + fields=['asset', 'status'], + condition=Q(status='in-progress'), + name='unique_asset_upload_in_progress' + ) + ] + triggers = generates_asset_upload_triggers() + + asset = models.ForeignKey(CollectionAsset, related_name='+', on_delete=models.CASCADE) + + def update_asset_from_upload(self): + '''Updating the asset's file:checksum and update_interval from the upload + + When the upload is completed, the new file:checksum and update interval from the upload + is set to its asset parent. + ''' + logger.debug( + 'Updating asset %s file:checksum from %s to %s and update_interval from %d to %d ' + 'due to upload complete', + self.asset.name, + self.asset.checksum_multihash, + self.checksum_multihash, + self.asset.update_interval, + self.update_interval, + extra={ + 'upload_id': self.upload_id, + 'asset': self.asset.name, + 'collection': self.asset.collection.name + } + ) + + self.asset.checksum_multihash = self.checksum_multihash + self.asset.update_interval = self.update_interval + self.asset.save() + + class CountBase(models.Model): '''CountBase tables are used to help calculate the summary on a collection. This is only performant if the distinct number of values is small, e.g. 
we currently only have diff --git a/app/stac_api/s3_multipart_upload.py b/app/stac_api/s3_multipart_upload.py index 5b1239f1..d17f5eec 100644 --- a/app/stac_api/s3_multipart_upload.py +++ b/app/stac_api/s3_multipart_upload.py @@ -65,6 +65,24 @@ def list_multipart_uploads(self, key=None, limit=100, start=None): response.get('NextUploadIdMarker', None), ) + def log_extra(self, asset, upload_id=None, parts=None): + if hasattr(asset, 'item'): + log_extra = { + 'collection': asset.item.collection.name, + 'item': asset.item.name, + 'asset': asset.name, + } + else: + log_extra = { + 'collection': asset.collection.name, + 'asset': asset.name, + } + if upload_id is not None: + log_extra['upload_id'] = upload_id + if parts is not None: + log_extra['parts'] = parts + return log_extra + def create_multipart_upload( self, key, asset, checksum_multihash, update_interval, content_encoding ): @@ -99,11 +117,7 @@ def create_multipart_upload( CacheControl=get_s3_cache_control_value(update_interval), ContentType=asset.media_type, **extra_params, - log_extra={ - 'collection': asset.item.collection.name, - 'item': asset.item.name, - 'asset': asset.name - } + log_extra=self.log_extra(asset) ) logger.info( 'S3 Multipart upload successfully created: upload_id=%s', @@ -148,12 +162,7 @@ def create_presigned_url(self, key, asset, part, upload_id, part_md5): Params=params, ExpiresIn=settings.AWS_PRESIGNED_URL_EXPIRES, HttpMethod='PUT', - log_extra={ - 'collection': asset.item.collection.name, - 'item': asset.item.name, - 'asset': asset.name, - 'upload_id': upload_id - } + log_extra=self.log_extra(asset, upload_id=upload_id) ) logger.info( @@ -191,13 +200,7 @@ def complete_multipart_upload(self, key, asset, parts, upload_id): Key=key, MultipartUpload={'Parts': parts}, UploadId=upload_id, - log_extra={ - 'parts': parts, - 'upload_id': upload_id, - 'collection': asset.item.collection.name, - 'item': asset.item.name, - 'asset': asset.name - } + log_extra=self.log_extra(asset, upload_id=upload_id, parts=parts) ) except ClientError as error: raise serializers.ValidationError(str(error)) from None @@ -271,12 +274,7 @@ def list_upload_parts(self, key, asset, upload_id, limit, offset): UploadId=upload_id, MaxParts=limit, PartNumberMarker=offset, - log_extra={ - 'collection': asset.item.collection.name, - 'item': asset.item.name, - 'asset': asset.name, - 'upload_id': upload_id - } + log_extra=self.log_extra(asset, upload_id=upload_id) ) return response, response.get('IsTruncated', False) diff --git a/app/stac_api/serializers.py b/app/stac_api/serializers.py index c24c531e..a01a572c 100644 --- a/app/stac_api/serializers.py +++ b/app/stac_api/serializers.py @@ -16,6 +16,7 @@ from stac_api.models import AssetUpload from stac_api.models import Collection from stac_api.models import CollectionAsset +from stac_api.models import CollectionAssetUpload from stac_api.models import CollectionLink from stac_api.models import Item from stac_api.models import ItemLink @@ -429,7 +430,7 @@ class Meta: collection = None def create(self, validated_data): - asset = validate_uniqueness_and_create(Asset, validated_data) + asset = validate_uniqueness_and_create(CollectionAsset, validated_data) return asset def update_or_create(self, look_up, validated_data): @@ -1073,3 +1074,91 @@ class Meta: parts = serializers.ListField( source='Parts', child=UploadPartSerializer(), default=list, read_only=True ) + + +class CollectionAssetUploadSerializer(NonNullModelSerializer): + + class Meta: + model = CollectionAssetUpload + list_serializer_class = 
AssetUploadListSerializer
+        fields = [
+            'upload_id',
+            'status',
+            'created',
+            'checksum_multihash',
+            'completed',
+            'aborted',
+            'number_parts',
+            'md5_parts',
+            'urls',
+            'ended',
+            'parts',
+            'update_interval',
+            'content_encoding'
+        ]
+
+    checksum_multihash = serializers.CharField(
+        source='checksum_multihash',
+        max_length=255,
+        required=True,
+        allow_blank=False,
+        validators=[validate_checksum_multihash_sha256]
+    )
+    md5_parts = serializers.JSONField(required=True)
+    update_interval = serializers.IntegerField(
+        required=False, allow_null=False, min_value=-1, max_value=3600, default=-1
+    )
+    content_encoding = serializers.CharField(
+        required=False,
+        allow_null=False,
+        allow_blank=False,
+        min_length=1,
+        max_length=32,
+        default='',
+        validators=[validate_content_encoding]
+    )
+
+    # write only fields
+    ended = serializers.DateTimeField(write_only=True, required=False)
+    parts = serializers.ListField(
+        child=UploadPartSerializer(), write_only=True, allow_empty=False, required=False
+    )
+
+    # Read only fields
+    upload_id = serializers.CharField(read_only=True)
+    created = serializers.DateTimeField(read_only=True)
+    urls = serializers.JSONField(read_only=True)
+    completed = serializers.SerializerMethodField()
+    aborted = serializers.SerializerMethodField()
+
+    def validate(self, attrs):
+        # md5_parts is only optional on partial updates: if the serializer is partial and
+        # md5_parts is missing that is acceptable, otherwise it is a validation error.
+        # Check the md5 parts length
+        if attrs.get('md5_parts') is not None:
+            validate_md5_parts(attrs['md5_parts'], attrs['number_parts'])
+        elif not self.partial:
+            raise serializers.ValidationError(
+                detail={'md5_parts': _('md5_parts parameter is missing')}, code='missing'
+            )
+        return attrs
+
+    def get_completed(self, obj):
+        if obj.status == CollectionAssetUpload.Status.COMPLETED:
+            return isoformat(obj.ended)
+        return None
+
+    def get_aborted(self, obj):
+        if obj.status == CollectionAssetUpload.Status.ABORTED:
+            return isoformat(obj.ended)
+        return None
+
+    def get_fields(self):
+        fields = super().get_fields()
+        # This is a hack to allow fields with special characters
+        fields['file:checksum'] = fields.pop('checksum_multihash')
+
+        # Older versions of the API still use a different name
+        request = self.context.get('request')
+        if not is_api_version_1(request):
+            fields['checksum:multihash'] = fields.pop('file:checksum')
+        return fields
diff --git a/app/stac_api/signals.py b/app/stac_api/signals.py
index 2005869a..0410b70f 100644
--- a/app/stac_api/signals.py
+++ b/app/stac_api/signals.py
@@ -7,6 +7,7 @@
 from stac_api.models import Asset
 from stac_api.models import AssetUpload
 from stac_api.models import CollectionAsset
+from stac_api.models import CollectionAssetUpload

 logger = logging.getLogger(__name__)

@@ -30,6 +31,24 @@ def check_on_going_upload(sender, instance, **kwargs):
     )


+@receiver(pre_delete, sender=CollectionAssetUpload)
+def check_on_going_collection_asset_upload(sender, instance, **kwargs):
+    if instance.status == CollectionAssetUpload.Status.IN_PROGRESS:
+        logger.error(
+            "Cannot delete collection asset %s due to upload %s which is still in progress",
+            instance.asset.name,
+            instance.upload_id,
+            extra={
+                'upload_id': instance.upload_id,
+                'asset': instance.asset.name,
+                'collection': instance.asset.collection.name
+            }
+        )
+        raise ProtectedError(
+            f"Collection Asset {instance.asset.name} has still an upload in progress", [instance]
+        )
+
+
 @receiver(pre_delete, sender=Asset)
 def delete_s3_asset(sender, instance, **kwargs):
     # The file is not automatically deleted by Django
diff --git
a/app/stac_api/urls.py b/app/stac_api/urls.py
index 9e3328e4..4f5e2ccd 100644
--- a/app/stac_api/urls.py
+++ b/app/stac_api/urls.py
@@ -13,6 +13,11 @@
 from stac_api.views import AssetUploadsList
 from stac_api.views import CollectionAssetDetail
 from stac_api.views import CollectionAssetsList
+from stac_api.views import CollectionAssetUploadAbort
+from stac_api.views import CollectionAssetUploadComplete
+from stac_api.views import CollectionAssetUploadDetail
+from stac_api.views import CollectionAssetUploadPartsList
+from stac_api.views import CollectionAssetUploadsList
 from stac_api.views import CollectionDetail
 from stac_api.views import CollectionList
 from stac_api.views import ConformancePageDetail
@@ -43,8 +48,35 @@
     path("/assets/", include(asset_urls))
 ]

+collection_asset_upload_urls = [
+    path(
+        "", CollectionAssetUploadDetail.as_view(), name='collection-asset-upload-detail'
+    ),
+    path(
+        "/parts",
+        CollectionAssetUploadPartsList.as_view(),
+        name='collection-asset-upload-parts-list'
+    ),
+    path(
+        "/complete",
+        CollectionAssetUploadComplete.as_view(),
+        name='collection-asset-upload-complete'
+    ),
+    path(
+        "/abort",
+        CollectionAssetUploadAbort.as_view(),
+        name='collection-asset-upload-abort'
+    ),
+]
+
 collection_asset_urls = [
     path("", CollectionAssetDetail.as_view(), name='collection-asset-detail'),
+    path(
+        "/uploads",
+        CollectionAssetUploadsList.as_view(),
+        name='collection-asset-uploads-list'
+    ),
+    path("/uploads/", include(collection_asset_upload_urls))
 ]

 collection_urls = [
diff --git a/app/stac_api/validators_view.py b/app/stac_api/validators_view.py
index 1200f879..c8bd571d 100644
--- a/app/stac_api/validators_view.py
+++ b/app/stac_api/validators_view.py
@@ -9,6 +9,7 @@
 from stac_api.models import Asset
 from stac_api.models import Collection
+from stac_api.models import CollectionAsset
 from stac_api.models import Item

 logger = logging.getLogger(__name__)

@@ -82,6 +83,30 @@ def validate_asset(kwargs):
     )


+def validate_collection_asset(kwargs):
+    '''Validate that the collection asset given in request kwargs exists
+
+    Args:
+        kwargs: dict
+            request kwargs dictionary
+
+    Raises:
+        Http404: when the asset doesn't exist
+    '''
+    if not CollectionAsset.objects.filter(
+        name=kwargs['asset_name'], collection__name=kwargs['collection_name']
+    ).exists():
+        logger.error(
+            "The asset %s is not part of the collection %s",
+            kwargs['asset_name'],
+            kwargs['collection_name']
+        )
+        raise Http404(
+            f"The asset {kwargs['asset_name']} is not part of "
+            f"the collection {kwargs['collection_name']}"
+        )
+
+
 def validate_renaming(serializer, original_id, id_field='name', extra_log=None):
     '''Validate that the object name is not different from the one defined in the data.
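For orientation, the new upload routes nest under the existing collection asset path, so the
resulting endpoints look roughly like this (a sketch: the URL converter placeholders such as
<upload_id> do not survive the mail formatting above, and the collections/<collection_name>
prefix is assumed from collection_urls):

    collections/<collection_name>/assets/<asset_name>/uploads                       -> collection-asset-uploads-list
    collections/<collection_name>/assets/<asset_name>/uploads/<upload_id>           -> collection-asset-upload-detail
    collections/<collection_name>/assets/<asset_name>/uploads/<upload_id>/parts     -> collection-asset-upload-parts-list
    collections/<collection_name>/assets/<asset_name>/uploads/<upload_id>/complete  -> collection-asset-upload-complete
    collections/<collection_name>/assets/<asset_name>/uploads/<upload_id>/abort     -> collection-asset-upload-abort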
diff --git a/app/stac_api/views.py b/app/stac_api/views.py
index 72b016df..275439e3 100644
--- a/app/stac_api/views.py
+++ b/app/stac_api/views.py
@@ -31,6 +31,7 @@
 from stac_api.models import AssetUpload
 from stac_api.models import Collection
 from stac_api.models import CollectionAsset
+from stac_api.models import CollectionAssetUpload
 from stac_api.models import Item
 from stac_api.models import LandingPage
 from stac_api.pagination import ExtApiPagination
@@ -40,6 +41,7 @@
 from stac_api.serializers import AssetUploadPartsSerializer
 from stac_api.serializers import AssetUploadSerializer
 from stac_api.serializers import CollectionAssetSerializer
+from stac_api.serializers import CollectionAssetUploadSerializer
 from stac_api.serializers import CollectionSerializer
 from stac_api.serializers import ConformancePageSerializer
 from stac_api.serializers import ItemSerializer
@@ -55,6 +57,7 @@
 from stac_api.validators_serializer import ValidateSearchRequest
 from stac_api.validators_view import validate_asset
 from stac_api.validators_view import validate_collection
+from stac_api.validators_view import validate_collection_asset
 from stac_api.validators_view import validate_item
 from stac_api.validators_view import validate_renaming

@@ -181,6 +184,20 @@ def get_asset_upload_etag(request, *args, **kwargs):
     )


+def get_collection_asset_upload_etag(request, *args, **kwargs):
+    '''Get the ETag for a collection asset upload object
+
+    The ETag is a UUID4 computed on each object change
+    '''
+    return get_etag(
+        CollectionAssetUpload.objects.filter(
+            asset__collection__name=kwargs['collection_name'],
+            asset__name=kwargs['asset_name'],
+            upload_id=kwargs['upload_id']
+        )
+    )
+
+
 class LandingPageDetail(generics.RetrieveAPIView):
     name = 'landing-page'  # this name must match the name in urls.py
     serializer_class = LandingPageSerializer
@@ -1001,3 +1018,222 @@ def get_pagination_config(self, request):
     def get_paginated_response(self, data, has_next):  # pylint: disable=arguments-differ
         return self.paginator.get_paginated_response(data, has_next)
+
+
+class CollectionAssetUploadBase(generics.GenericAPIView):
+    serializer_class = CollectionAssetUploadSerializer
+    lookup_url_kwarg = "upload_id"
+    lookup_field = "upload_id"
+
+    def get_queryset(self):
+        return CollectionAssetUpload.objects.filter(
+            asset__collection__name=self.kwargs['collection_name'],
+            asset__name=self.kwargs['asset_name']
+        ).prefetch_related('asset')
+
+    def get_in_progress_queryset(self):
+        return self.get_queryset().filter(status=CollectionAssetUpload.Status.IN_PROGRESS)
+
+    def get_asset_or_404(self):
+        return get_object_or_404(
+            CollectionAsset.objects.all(),
+            name=self.kwargs['asset_name'],
+            collection__name=self.kwargs['collection_name']
+        )
+
+    def _save_asset_upload(self, executor, serializer, key, asset, upload_id, urls):
+        try:
+            with transaction.atomic():
+                serializer.save(asset=asset, upload_id=upload_id, urls=urls)
+        except IntegrityError as error:
+            logger.error(
+                'Failed to create asset upload multipart: %s',
+                error,
+                extra={
+                    'collection': asset.collection.name, 'asset': asset.name
+                }
+            )
+            if bool(self.get_in_progress_queryset()):
+                raise UploadInProgressError(
+                    data={"upload_id": self.get_in_progress_queryset()[0].upload_id}
+                ) from None
+            raise
+
+    def create_multipart_upload(self, executor, serializer, validated_data, asset):
+        key = get_collection_asset_path(asset.collection, asset.name)
+
+        upload_id = executor.create_multipart_upload(
+            key,
+            asset,
+            validated_data['checksum_multihash'],
validated_data['update_interval'], + validated_data['content_encoding'] + ) + urls = [] + sorted_md5_parts = sorted(validated_data['md5_parts'], key=itemgetter('part_number')) + + try: + for part in sorted_md5_parts: + urls.append( + executor.create_presigned_url( + key, asset, part['part_number'], upload_id, part['md5'] + ) + ) + + self._save_asset_upload(executor, serializer, key, asset, upload_id, urls) + except APIException as err: + executor.abort_multipart_upload(key, asset, upload_id) + raise + + def complete_multipart_upload(self, executor, validated_data, asset_upload, asset): + key = get_collection_asset_path(asset.collection, asset.name) + parts = validated_data.get('parts', None) + if parts is None: + raise serializers.ValidationError({ + 'parts': _("Missing required field") + }, code='missing') + if len(parts) > asset_upload.number_parts: + raise serializers.ValidationError({'parts': [_("Too many parts")]}, code='invalid') + if len(parts) < asset_upload.number_parts: + raise serializers.ValidationError({'parts': [_("Too few parts")]}, code='invalid') + if asset_upload.status != CollectionAssetUpload.Status.IN_PROGRESS: + raise UploadNotInProgressError() + executor.complete_multipart_upload(key, asset, parts, asset_upload.upload_id) + asset_upload.update_asset_from_upload() + asset_upload.status = CollectionAssetUpload.Status.COMPLETED + asset_upload.ended = utc_aware(datetime.utcnow()) + asset_upload.urls = [] + asset_upload.save() + + def abort_multipart_upload(self, executor, asset_upload, asset): + key = get_collection_asset_path(asset.collection, asset.name) + executor.abort_multipart_upload(key, asset, asset_upload.upload_id) + asset_upload.status = CollectionAssetUpload.Status.ABORTED + asset_upload.ended = utc_aware(datetime.utcnow()) + asset_upload.urls = [] + asset_upload.save() + + def list_multipart_upload_parts(self, executor, asset_upload, asset, limit, offset): + key = get_collection_asset_path(asset.collection, asset.name) + return executor.list_upload_parts(key, asset, asset_upload.upload_id, limit, offset) + + +class CollectionAssetUploadsList( + CollectionAssetUploadBase, mixins.ListModelMixin, views_mixins.CreateModelMixin +): + + class ExternalDisallowedException(Exception): + pass + + def post(self, request, *args, **kwargs): + try: + return self.create(request, *args, **kwargs) + except self.ExternalDisallowedException as ex: + data = { + "code": 400, + "description": "Not allowed to create multipart uploads on external assets" + } + return Response(status=400, exception=True, data=data) + + def get(self, request, *args, **kwargs): + validate_collection_asset(self.kwargs) + return self.list(request, *args, **kwargs) + + def get_success_headers(self, data): + return {'Location': '/'.join([self.request.build_absolute_uri(), data['upload_id']])} + + def perform_create(self, serializer): + data = serializer.validated_data + asset = self.get_asset_or_404() + collection = asset.collection + + if asset.is_external: + raise self.ExternalDisallowedException() + + s3_bucket = select_s3_bucket(collection.name) + executor = MultipartUpload(s3_bucket) + + self.create_multipart_upload(executor, serializer, data, asset) + + def get_queryset(self): + queryset = super().get_queryset() + + status = self.request.query_params.get('status', None) + if status: + queryset = queryset.filter_by_status(status) + + return queryset + + +class CollectionAssetUploadDetail( + CollectionAssetUploadBase, mixins.RetrieveModelMixin, views_mixins.DestroyModelMixin +): + + 
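+    # Detail view for a single upload: GET is made conditional via the upload's ETag
+    # (see get_collection_asset_upload_etag above); deletion support comes from
+    # views_mixins.DestroyModelMixin.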
@etag(get_collection_asset_upload_etag) + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + +class CollectionAssetUploadComplete(CollectionAssetUploadBase, views_mixins.UpdateInsertModelMixin): + + def post(self, request, *args, **kwargs): + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + + def perform_update(self, serializer): + asset = serializer.instance.asset + + collection = asset.collection + + s3_bucket = select_s3_bucket(collection.name) + executor = MultipartUpload(s3_bucket) + + self.complete_multipart_upload( + executor, serializer.validated_data, serializer.instance, asset + ) + + +class CollectionAssetUploadAbort(CollectionAssetUploadBase, views_mixins.UpdateInsertModelMixin): + + def post(self, request, *args, **kwargs): + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + + def perform_update(self, serializer): + asset = serializer.instance.asset + + collection = asset.collection + + s3_bucket = select_s3_bucket(collection.name) + executor = MultipartUpload(s3_bucket) + self.abort_multipart_upload(executor, serializer.instance, asset) + + +class CollectionAssetUploadPartsList(CollectionAssetUploadBase): + serializer_class = AssetUploadPartsSerializer + pagination_class = ExtApiPagination + + def get(self, request, *args, **kwargs): + return self.list(request, *args, **kwargs) + + def list(self, request, *args, **kwargs): + asset_upload = self.get_object() + limit, offset = self.get_pagination_config(request) + + collection = asset_upload.asset.collection + s3_bucket = select_s3_bucket(collection.name) + + executor = MultipartUpload(s3_bucket) + + data, has_next = self.list_multipart_upload_parts( + executor, asset_upload, asset_upload.asset, limit, offset + ) + serializer = self.get_serializer(data) + + return self.get_paginated_response(serializer.data, has_next) + + def get_pagination_config(self, request): + return self.paginator.get_pagination_config(request) + + def get_paginated_response(self, data, has_next): # pylint: disable=arguments-differ + return self.paginator.get_paginated_response(data, has_next) diff --git a/app/tests/tests_10/test_collection_asset_upload_endpoint.py b/app/tests/tests_10/test_collection_asset_upload_endpoint.py new file mode 100644 index 00000000..4945b1a2 --- /dev/null +++ b/app/tests/tests_10/test_collection_asset_upload_endpoint.py @@ -0,0 +1,1266 @@ +# pylint: disable=too-many-ancestors,too-many-lines +import gzip +import hashlib +import logging +from base64 import b64encode +from datetime import datetime +from urllib import parse + +from django.conf import settings +from django.contrib.auth import get_user_model +from django.test import Client + +from stac_api.models import CollectionAsset +from stac_api.models import CollectionAssetUpload +from stac_api.utils import fromisoformat +from stac_api.utils import get_collection_asset_path +from stac_api.utils import get_s3_client +from stac_api.utils import get_sha256_multihash +from stac_api.utils import utc_aware + +from tests.tests_10.base_test import StacBaseTestCase +from tests.tests_10.base_test import StacBaseTransactionTestCase +from tests.tests_10.data_factory import Factory +from tests.tests_10.utils import reverse_version +from tests.utils import S3TestMixin +from tests.utils import client_login +from tests.utils import get_file_like_object +from tests.utils import mock_s3_asset_file + +logger = logging.getLogger(__name__) + +KB = 1024 +MB = 1024 * KB +GB = 1024 * MB + + +def 
base64_md5(data): + return b64encode(hashlib.md5(data).digest()).decode('utf-8') + + +def create_md5_parts(number_parts, offset, file_like): + return [{ + 'part_number': i + 1, 'md5': base64_md5(file_like[i * offset:(i + 1) * offset]) + } for i in range(number_parts)] + + +class CollectionAssetUploadBaseTest(StacBaseTestCase, S3TestMixin): + + @mock_s3_asset_file + def setUp(self): # pylint: disable=invalid-name + self.client = Client() + client_login(self.client) + self.factory = Factory() + self.collection = self.factory.create_collection_sample().model + self.asset = self.factory.create_collection_asset_sample( + collection=self.collection, sample='asset-no-file' + ).model + self.maxDiff = None # pylint: disable=invalid-name + + def get_asset_upload_queryset(self): + return CollectionAssetUpload.objects.all().filter( + asset__collection__name=self.collection.name, + asset__name=self.asset.name, + ) + + def get_delete_asset_path(self): + return reverse_version( + 'collection-asset-detail', args=[self.collection.name, self.asset.name] + ) + + def get_get_multipart_uploads_path(self): + return reverse_version( + 'collection-asset-uploads-list', args=[self.collection.name, self.asset.name] + ) + + def get_create_multipart_upload_path(self): + return reverse_version( + 'collection-asset-uploads-list', args=[self.collection.name, self.asset.name] + ) + + def get_abort_multipart_upload_path(self, upload_id): + return reverse_version( + 'collection-asset-upload-abort', + args=[self.collection.name, self.asset.name, upload_id] + ) + + def get_complete_multipart_upload_path(self, upload_id): + return reverse_version( + 'collection-asset-upload-complete', + args=[self.collection.name, self.asset.name, upload_id] + ) + + def get_list_parts_path(self, upload_id): + return reverse_version( + 'collection-asset-upload-parts-list', + args=[self.collection.name, self.asset.name, upload_id] + ) + + def s3_upload_parts(self, upload_id, file_like, size, number_parts): + s3 = get_s3_client() + key = get_collection_asset_path(self.collection, self.asset.name) + parts = [] + # split the file into parts + start = 0 + offset = size // number_parts + for part in range(1, number_parts + 1): + # use the s3 client to upload the file instead of the presigned url due to the s3 + # mocking + response = s3.upload_part( + Body=file_like[start:start + offset], + Bucket=settings.AWS_SETTINGS['legacy']['S3_BUCKET_NAME'], + Key=key, + PartNumber=part, + UploadId=upload_id + ) + start += offset + parts.append({'etag': response['ETag'], 'part_number': part}) + return parts + + def check_urls_response(self, urls, number_parts): + now = utc_aware(datetime.utcnow()) + self.assertEqual(len(urls), number_parts) + for i, url in enumerate(urls): + self.assertListEqual( + list(url.keys()), ['url', 'part', 'expires'], msg='Url dictionary keys missing' + ) + self.assertEqual( + url['part'], i + 1, msg=f'Part {url["part"]} does not match the url index {i}' + ) + try: + url_parsed = parse.urlparse(url["url"]) + self.assertIn(url_parsed[0], ['http', 'https']) + except ValueError as error: + self.fail(msg=f"Invalid url {url['url']} for part {url['part']}: {error}") + try: + expires_dt = fromisoformat(url['expires']) + self.assertGreater( + expires_dt, + now, + msg=f"expires {url['expires']} for part {url['part']} is not in future" + ) + except ValueError as error: + self.fail(msg=f"Invalid expires {url['expires']} for part {url['part']}: {error}") + + def check_created_response(self, json_response): + self.assertNotIn('completed', 
json_response) + self.assertNotIn('aborted', json_response) + self.assertIn('upload_id', json_response) + self.assertIn('status', json_response) + self.assertIn('number_parts', json_response) + self.assertIn('file:checksum', json_response) + self.assertIn('urls', json_response) + self.assertEqual(json_response['status'], 'in-progress') + + def check_completed_response(self, json_response): + self.assertNotIn('urls', json_response) + self.assertNotIn('aborted', json_response) + self.assertIn('upload_id', json_response) + self.assertIn('status', json_response) + self.assertIn('number_parts', json_response) + self.assertIn('file:checksum', json_response) + self.assertIn('completed', json_response) + self.assertEqual(json_response['status'], 'completed') + self.assertGreater( + fromisoformat(json_response['completed']), fromisoformat(json_response['created']) + ) + + def check_aborted_response(self, json_response): + self.assertNotIn('urls', json_response) + self.assertNotIn('completed', json_response) + self.assertIn('upload_id', json_response) + self.assertIn('status', json_response) + self.assertIn('number_parts', json_response) + self.assertIn('file:checksum', json_response) + self.assertIn('aborted', json_response) + self.assertEqual(json_response['status'], 'aborted') + self.assertGreater( + fromisoformat(json_response['aborted']), fromisoformat(json_response['created']) + ) + + +class CollectionAssetUploadCreateEndpointTestCase(CollectionAssetUploadBaseTest): + + def test_asset_upload_create_abort_multipart(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 2 + file_like, checksum_multihash = get_file_like_object(1 * KB) + offset = 1 * KB // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'file:checksum': checksum_multihash, + 'md5_parts': md5_parts + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_created_response(json_data) + + self.check_urls_response(json_data['urls'], number_parts) + + response = self.client.post( + self.get_abort_multipart_upload_path(json_data['upload_id']), + data={}, + content_type="application/json" + ) + self.assertStatusCode(200, response) + json_data = response.json() + self.check_aborted_response(json_data) + self.assertFalse( + self.get_asset_upload_queryset().filter( + status=CollectionAssetUpload.Status.IN_PROGRESS + ).exists(), + msg='In progress upload found' + ) + self.assertTrue( + self.get_asset_upload_queryset().filter(status=CollectionAssetUpload.Status.ABORTED + ).exists(), + msg='Aborted upload not found' + ) + # check that there is only one multipart upload on S3 + s3 = get_s3_client() + response = s3.list_multipart_uploads( + Bucket=settings.AWS_SETTINGS['legacy']['S3_BUCKET_NAME'], KeyMarker=key + ) + self.assertNotIn('Uploads', response, msg='uploads found on S3') + + def test_asset_upload_create_multipart_duplicate(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 2 + file_like, checksum_multihash = get_file_like_object(1 * KB) + offset = 1 * KB // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'file:checksum': 
checksum_multihash,
+                'md5_parts': md5_parts
+            },
+            content_type="application/json"
+        )
+        self.assertStatusCode(201, response)
+        json_data = response.json()
+        self.check_created_response(json_data)
+        self.check_urls_response(json_data['urls'], number_parts)
+        initial_upload_id = json_data['upload_id']
+
+        response = self.client.post(
+            self.get_create_multipart_upload_path(),
+            data={
+                'number_parts': number_parts,
+                'file:checksum': checksum_multihash,
+                'md5_parts': md5_parts
+            },
+            content_type="application/json"
+        )
+        self.assertStatusCode(409, response)
+        self.assertEqual(response.json()['description'], "Upload already in progress")
+        self.assertIn(
+            "upload_id",
+            response.json(),
+            msg="The upload id of the current upload is missing from response"
+        )
+        self.assertEqual(
+            initial_upload_id,
+            response.json()['upload_id'],
+            msg="Current upload ID not matching the one from the 409 Conflict response"
+        )
+
+        self.assertEqual(
+            self.get_asset_upload_queryset().filter(
+                status=CollectionAssetUpload.Status.IN_PROGRESS
+            ).count(),
+            1,
+            msg='More than one upload in progress'
+        )
+
+        # check that there is only one multipart upload on S3
+        s3 = get_s3_client()
+        response = s3.list_multipart_uploads(
+            Bucket=settings.AWS_SETTINGS['legacy']['S3_BUCKET_NAME'], KeyMarker=key
+        )
+        self.assertIn('Uploads', response, msg='Failed to retrieve the upload list from s3')
+        self.assertEqual(len(response['Uploads']), 1, msg='Unexpected number of uploads found on S3')
+
+
+class CollectionAssetUploadCreateRaceConditionTest(StacBaseTransactionTestCase, S3TestMixin):
+
+    @mock_s3_asset_file
+    def setUp(self):
+        self.username = 'user'
+        self.password = 'dummy-password'
+        get_user_model().objects.create_superuser(self.username, password=self.password)
+        self.factory = Factory()
+        self.collection = self.factory.create_collection_sample().model
+        self.asset = self.factory.create_collection_asset_sample(
+            collection=self.collection, sample='asset-no-file'
+        ).model
+
+    def test_asset_upload_create_race_condition(self):
+        workers = 5
+
+        key = get_collection_asset_path(self.collection, self.asset.name)
+        self.assertS3ObjectNotExists(key)
+        number_parts = 2
+        file_like, checksum_multihash = get_file_like_object(1 * KB)
+        offset = 1 * KB // number_parts
+        md5_parts = create_md5_parts(number_parts, offset, file_like)
+        path = reverse_version(
+            'collection-asset-uploads-list', args=[self.collection.name, self.asset.name]
+        )
+
+        def asset_upload_atomic_create_test(worker):
+            # This method runs on a separate thread, therefore it needs to create a new
+            # client and log it in for each call.
+            client = Client()
+            client.login(username=self.username, password=self.password)
+            return client.post(
+                path,
+                data={
+                    'number_parts': number_parts,
+                    'file:checksum': checksum_multihash,
+                    'md5_parts': md5_parts
+                },
+                content_type="application/json"
+            )

+        # We call the POST asset upload several times in parallel with the same data to make sure
+        # that we don't have any race condition.
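+        # Only one of the parallel requests should get a 201: the concurrent INSERTs race on
+        # the unique constraint for in-progress uploads (unique_asset_upload_in_progress), and
+        # the losers' IntegrityError is turned into a 409 UploadInProgressError by
+        # _save_asset_upload() in the view.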
+ results, errors = self.run_parallel(workers, asset_upload_atomic_create_test) + + for _, response in results: + self.assertStatusCode([201, 409], response) + + ok_201 = [r for _, r in results if r.status_code == 201] + bad_409 = [r for _, r in results if r.status_code == 409] + self.assertEqual(len(ok_201), 1, msg="More than 1 parallel request was successful") + for response in bad_409: + self.assertEqual(response.json()['description'], "Upload already in progress") + + +class CollectionAssetUpload1PartEndpointTestCase(CollectionAssetUploadBaseTest): + + def upload_asset_with_dyn_cache(self, update_interval=None): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 1 + size = 1 * KB + file_like, checksum_multihash = get_file_like_object(size) + md5_parts = [{'part_number': 1, 'md5': base64_md5(file_like)}] + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'md5_parts': md5_parts, + 'file:checksum': checksum_multihash, + 'update_interval': update_interval + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_created_response(json_data) + self.check_urls_response(json_data['urls'], number_parts) + self.assertIn('md5_parts', json_data) + self.assertEqual(json_data['md5_parts'], md5_parts) + + parts = self.s3_upload_parts(json_data['upload_id'], file_like, size, number_parts) + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': parts}, + content_type="application/json" + ) + self.assertStatusCode(200, response) + self.check_completed_response(response.json()) + return key + + def test_asset_upload_1_part_md5_integrity(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 1 + size = 1 * KB + file_like, checksum_multihash = get_file_like_object(size) + md5_parts = [{'part_number': 1, 'md5': base64_md5(file_like)}] + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'md5_parts': md5_parts, + 'file:checksum': checksum_multihash + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_created_response(json_data) + self.check_urls_response(json_data['urls'], number_parts) + self.assertIn('md5_parts', json_data) + self.assertEqual(json_data['md5_parts'], md5_parts) + + parts = self.s3_upload_parts(json_data['upload_id'], file_like, size, number_parts) + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': parts}, + content_type="application/json" + ) + self.assertStatusCode(200, response) + self.check_completed_response(response.json()) + self.assertS3ObjectExists(key) + obj = self.get_s3_object(key) + self.assertS3ObjectCacheControl(obj, key, max_age=7200) + self.assertS3ObjectContentEncoding(obj, key, None) + + def test_asset_upload_dyn_cache(self): + key = self.upload_asset_with_dyn_cache(update_interval=600) + self.assertS3ObjectExists(key) + obj = self.get_s3_object(key) + self.assertS3ObjectCacheControl(obj, key, max_age=8) + + def test_asset_upload_no_cache(self): + key = self.upload_asset_with_dyn_cache(update_interval=5) + self.assertS3ObjectExists(key) + obj = self.get_s3_object(key) + self.assertS3ObjectCacheControl(obj, key, no_cache=True) + + def 
test_asset_upload_no_content_encoding(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 1 + size = 1 * KB + file_like, checksum_multihash = get_file_like_object(size) + md5_parts = [{'part_number': 1, 'md5': base64_md5(file_like)}] + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'md5_parts': md5_parts, + 'file:checksum': checksum_multihash + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_created_response(json_data) + self.check_urls_response(json_data['urls'], number_parts) + self.assertIn('md5_parts', json_data) + self.assertEqual(json_data['md5_parts'], md5_parts) + + parts = self.s3_upload_parts(json_data['upload_id'], file_like, size, number_parts) + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': parts}, + content_type="application/json" + ) + self.assertStatusCode(200, response) + self.check_completed_response(response.json()) + self.assertS3ObjectExists(key) + obj = self.get_s3_object(key) + self.assertS3ObjectContentEncoding(obj, key, None) + + def test_asset_upload_gzip(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 1 + size = 1 * MB + file_like, checksum_multihash = get_file_like_object(size) + file_like_compress = gzip.compress(file_like) + size_compress = len(file_like_compress) + md5_parts = [{'part_number': 1, 'md5': base64_md5(file_like_compress)}] + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'md5_parts': md5_parts, + 'file:checksum': checksum_multihash, + 'content_encoding': 'gzip' + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_created_response(json_data) + self.check_urls_response(json_data['urls'], number_parts) + self.assertIn('md5_parts', json_data) + self.assertEqual(json_data['md5_parts'], md5_parts) + + parts = self.s3_upload_parts( + json_data['upload_id'], file_like_compress, size_compress, number_parts + ) + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': parts}, + content_type="application/json" + ) + self.assertStatusCode(200, response) + self.check_completed_response(response.json()) + self.assertS3ObjectExists(key) + obj = self.get_s3_object(key) + self.assertS3ObjectContentEncoding(obj, key, encoding='gzip') + + +class CollectionAssetUpload2PartEndpointTestCase(CollectionAssetUploadBaseTest): + + def test_asset_upload_2_parts_md5_integrity(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 2 + size = 10 * MB # Minimum upload part on S3 is 5 MB + file_like, checksum_multihash = get_file_like_object(size) + + offset = size // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'md5_parts': md5_parts, + 'file:checksum': checksum_multihash + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_created_response(json_data) + self.check_urls_response(json_data['urls'], number_parts) + + parts = 
self.s3_upload_parts(json_data['upload_id'], file_like, size, number_parts) + + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': parts}, + content_type="application/json" + ) + self.assertStatusCode(200, response) + self.check_completed_response(response.json()) + self.assertS3ObjectExists(key) + + +class CollectionAssetUploadInvalidEndpointTestCase(CollectionAssetUploadBaseTest): + + def test_asset_upload_invalid_content_encoding(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 2 + size = 10 * MB # Minimum upload part on S3 is 5 MB + file_like, checksum_multihash = get_file_like_object(size) + offset = size // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'md5_parts': md5_parts, + 'file:checksum': checksum_multihash, + 'content_encoding': 'hello world' + }, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + {'content_encoding': ['Invalid encoding "hello world": must be one of ' + '"br, gzip"']} + ) + + def test_asset_upload_1_part_no_md5(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 1 + size = 1 * KB + file_like, checksum_multihash = get_file_like_object(size) + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, 'file:checksum': checksum_multihash + }, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual(response.json()['description'], {'md5_parts': ['This field is required.']}) + + def test_asset_upload_2_parts_no_md5(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 2 + size = 10 * MB # Minimum upload part on S3 is 5 MB + file_like, checksum_multihash = get_file_like_object(size) + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, 'file:checksum': checksum_multihash + }, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual(response.json()['description'], {'md5_parts': ['This field is required.']}) + + def test_asset_upload_create_empty_payload(self): + response = self.client.post( + self.get_create_multipart_upload_path(), data={}, content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + { + 'file:checksum': ['This field is required.'], + 'number_parts': ['This field is required.'], + 'md5_parts': ['This field is required.'] + } + ) + + def test_asset_upload_create_invalid_data(self): + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': 0, + "file:checksum": 'abcdef', + "md5_parts": [{ + "part_number": '0', "md5": 'abcdef' + }] + }, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + { + 'file:checksum': ['Invalid multihash value; Invalid varint provided'], + 'number_parts': ['Ensure this value is greater than or equal to 1.'] + } + ) + + def test_asset_upload_create_too_many_parts(self): + + number_parts = 101 + md5_parts = [{'part_number': i + 1, 
'md5': 'abcdef'} for i in range(number_parts)] + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': 101, "file:checksum": 'abcdef', 'md5_parts': md5_parts + }, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + { + 'file:checksum': ['Invalid multihash value; Invalid varint provided'], + 'number_parts': ['Ensure this value is less than or equal to 100.'] + } + ) + + def test_asset_upload_create_empty_md5_parts(self): + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': 2, + "md5_parts": [], + "file:checksum": + '12200ADEC47F803A8CF1055ED36750B3BA573C79A3AF7DA6D6F5A2AED03EA16AF3BC' + }, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + { + 'non_field_errors': [ + 'Missing, too many or duplicate part_number in md5_parts field list: ' + 'list should have 2 item(s).' + ] + } + ) + + def test_asset_upload_create_duplicate_md5_parts(self): + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': 3, + "md5_parts": [{ + 'part_number': 1, 'md5': 'asdf' + }, { + 'part_number': 1, 'md5': 'asdf' + }, { + 'part_number': 2, 'md5': 'asdf' + }], + "file:checksum": + '12200ADEC47F803A8CF1055ED36750B3BA573C79A3AF7DA6D6F5A2AED03EA16AF3BC' + }, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + { + 'non_field_errors': [ + 'Missing, too many or duplicate part_number in md5_parts field list: ' + 'list should have 3 item(s).' + ] + } + ) + + def test_asset_upload_create_too_many_md5_parts(self): + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': 2, + "md5_parts": [{ + 'part_number': 1, 'md5': 'asdf' + }, { + 'part_number': 2, 'md5': 'asdf' + }, { + 'part_number': 3, 'md5': 'asdf' + }], + "file:checksum": + '12200ADEC47F803A8CF1055ED36750B3BA573C79A3AF7DA6D6F5A2AED03EA16AF3BC' + }, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + { + 'non_field_errors': [ + 'Missing, too many or duplicate part_number in md5_parts field list: ' + 'list should have 2 item(s).' 
+ ] + } + ) + + def test_asset_upload_create_md5_parts_missing_part_number(self): + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': 2, + "md5_parts": [ + { + 'part_number': 1, 'md5': 'asdf' + }, + { + 'md5': 'asdf' + }, + ], + "file:checksum": + '12200ADEC47F803A8CF1055ED36750B3BA573C79A3AF7DA6D6F5A2AED03EA16AF3BC' + }, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + {'non_field_errors': ['Invalid md5_parts[1] value: part_number field missing']} + ) + + def test_asset_upload_2_parts_too_small(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 2 + size = 1 * KB # Minimum upload part on S3 is 5 MB + file_like, checksum_multihash = get_file_like_object(size) + offset = size // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'file:checksum': checksum_multihash, + 'md5_parts': md5_parts + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_urls_response(json_data['urls'], number_parts) + + parts = self.s3_upload_parts(json_data['upload_id'], file_like, size, number_parts) + + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': parts}, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + [ + 'An error occurred (EntityTooSmall) when calling the CompleteMultipartUpload ' + 'operation: Your proposed upload is smaller than the minimum allowed object size.' + ] + ) + self.assertS3ObjectNotExists(key) + + def test_asset_upload_1_parts_invalid_etag(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 1 + size = 1 * KB + file_like, checksum_multihash = get_file_like_object(size) + offset = size // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'file:checksum': checksum_multihash, + 'md5_parts': md5_parts + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_urls_response(json_data['urls'], number_parts) + + parts = self.s3_upload_parts(json_data['upload_id'], file_like, size, number_parts) + + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': [{ + 'etag': 'dummy', 'part_number': 1 + }]}, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + [ + 'An error occurred (InvalidPart) when calling the CompleteMultipartUpload ' + 'operation: One or more of the specified parts could not be found. The part ' + 'might not have been uploaded, or the specified entity tag might not have ' + "matched the part's entity tag." 
+ ] + ) + self.assertS3ObjectNotExists(key) + + def test_asset_upload_1_parts_too_many_parts_in_complete(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 1 + size = 1 * KB + file_like, checksum_multihash = get_file_like_object(size) + offset = size // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'file:checksum': checksum_multihash, + 'md5_parts': md5_parts + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_urls_response(json_data['urls'], number_parts) + + parts = self.s3_upload_parts(json_data['upload_id'], file_like, size, number_parts) + parts.append({'etag': 'dummy', 'number_part': 2}) + + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': parts}, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual(response.json()['description'], {'parts': ['Too many parts']}) + self.assertS3ObjectNotExists(key) + + def test_asset_upload_2_parts_incomplete_upload(self): + number_parts = 2 + size = 10 * MB + file_like, checksum_multihash = get_file_like_object(size) + offset = size // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'file:checksum': checksum_multihash, + 'md5_parts': md5_parts + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_urls_response(json_data['urls'], number_parts) + + parts = self.s3_upload_parts(json_data['upload_id'], file_like, size // 2, 1) + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': parts}, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual(response.json()['description'], {'parts': ['Too few parts']}) + + def test_asset_upload_1_parts_invalid_complete(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 1 + size = 1 * KB + file_like, checksum_multihash = get_file_like_object(size) + offset = size // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'file:checksum': checksum_multihash, + 'md5_parts': md5_parts + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_urls_response(json_data['urls'], number_parts) + + parts = self.s3_upload_parts(json_data['upload_id'], file_like, size, number_parts) + + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={}, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual(response.json()['description'], {'parts': 'Missing required field'}) + + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': []}, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual(response.json()['description'], {'parts': ['This list may not be empty.']}) 
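+        # A parts entry that is not a dictionary must also be rejected: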
+ + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': ["dummy-etag"]}, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + { + 'parts': { + '0': { + 'non_field_errors': + ['Invalid data. Expected a dictionary, ' + 'but got str.'] + } + } + } + ) + self.assertS3ObjectNotExists(key) + + def test_asset_upload_1_parts_duplicate_complete(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 1 + size = 1 * KB + file_like, checksum_multihash = get_file_like_object(size) + offset = size // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'file:checksum': checksum_multihash, + 'md5_parts': md5_parts + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_urls_response(json_data['urls'], number_parts) + + parts = self.s3_upload_parts(json_data['upload_id'], file_like, size, number_parts) + + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': parts}, + content_type="application/json" + ) + self.assertStatusCode(200, response) + + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': parts}, + content_type="application/json" + ) + self.assertStatusCode(409, response) + self.assertEqual(response.json()['code'], 409) + self.assertEqual(response.json()['description'], 'No upload in progress') + + +class CollectionAssetUploadDeleteInProgressEndpointTestCase(CollectionAssetUploadBaseTest): + + def test_delete_asset_upload_in_progress(self): + number_parts = 2 + size = 10 * MB # Minimum upload part on S3 is 5 MB + file_like, checksum_multihash = get_file_like_object(size) + offset = size // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'file:checksum': checksum_multihash, + 'md5_parts': md5_parts + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + upload_id = response.json()['upload_id'] + + response = self.client.delete(self.get_delete_asset_path()) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + ['Collection Asset collection-asset-1.tiff has still an upload in progress'] + ) + + self.assertTrue( + CollectionAsset.objects.all().filter( + collection__name=self.collection.name, name=self.asset.name + ).exists(), + msg='Collection Asset has been deleted' + ) + + response = self.client.post(self.get_abort_multipart_upload_path(upload_id)) + self.assertStatusCode(200, response) + + response = self.client.delete(self.get_delete_asset_path()) + self.assertStatusCode(200, response) + + self.assertFalse( + CollectionAsset.objects.all().filter( + collection__name=self.collection.name, name=self.asset.name + ).exists(), + msg='Collection Asset has not been deleted' + ) + + +class GetCollectionAssetUploadsEndpointTestCase(CollectionAssetUploadBaseTest): + + def create_dummies_uploads(self): + # Create some asset uploads + for i in range(1, 4): + CollectionAssetUpload.objects.create( + asset=self.asset, + upload_id=f'upload-{i}', + 
status=CollectionAssetUpload.Status.ABORTED, + checksum_multihash=get_sha256_multihash(b'upload-%d' % i), + number_parts=2, + ended=utc_aware(datetime.utcnow()), + md5_parts=[] + ) + for i in range(4, 8): + CollectionAssetUpload.objects.create( + asset=self.asset, + upload_id=f'upload-{i}', + status=CollectionAssetUpload.Status.COMPLETED, + checksum_multihash=get_sha256_multihash(b'upload-%d' % i), + number_parts=2, + ended=utc_aware(datetime.utcnow()), + md5_parts=[] + ) + CollectionAssetUpload.objects.create( + asset=self.asset, + upload_id='upload-8', + status=CollectionAssetUpload.Status.IN_PROGRESS, + checksum_multihash=get_sha256_multihash(b'upload-8'), + number_parts=2, + md5_parts=[] + ) + self.maxDiff = None # pylint: disable=invalid-name + + def test_get_asset_uploads(self): + self.create_dummies_uploads() + response = self.client.get(self.get_get_multipart_uploads_path()) + self.assertStatusCode(200, response) + json_data = response.json() + self.assertIn('links', json_data) + self.assertEqual(json_data['links'], []) + self.assertIn('uploads', json_data) + self.assertEqual(len(json_data['uploads']), self.get_asset_upload_queryset().count()) + self.assertEqual( + [ + 'upload_id', + 'status', + 'created', + 'aborted', + 'number_parts', + 'update_interval', + 'content_encoding', + 'file:checksum' + ], + list(json_data['uploads'][0].keys()), + ) + self.assertEqual( + [ + 'upload_id', + 'status', + 'created', + 'completed', + 'number_parts', + 'update_interval', + 'content_encoding', + 'file:checksum' + ], + list(json_data['uploads'][4].keys()), + ) + self.assertEqual( + [ + 'upload_id', + 'status', + 'created', + 'number_parts', + 'update_interval', + 'content_encoding', + 'file:checksum' + ], + list(json_data['uploads'][7].keys()), + ) + + def test_get_asset_uploads_with_content_encoding(self): + CollectionAssetUpload.objects.create( + asset=self.asset, + upload_id='upload-content-encoding', + status=CollectionAssetUpload.Status.COMPLETED, + checksum_multihash=get_sha256_multihash(b'upload-content-encoding'), + number_parts=2, + ended=utc_aware(datetime.utcnow()), + md5_parts=[], + content_encoding='gzip' + ) + response = self.client.get(self.get_get_multipart_uploads_path()) + self.assertStatusCode(200, response) + json_data = response.json() + self.assertIn('links', json_data) + self.assertEqual(json_data['links'], []) + self.assertIn('uploads', json_data) + self.assertEqual(len(json_data['uploads']), self.get_asset_upload_queryset().count()) + self.assertEqual( + [ + 'upload_id', + 'status', + 'created', + 'completed', + 'number_parts', + 'update_interval', + 'content_encoding', + 'file:checksum' + ], + list(json_data['uploads'][0].keys()), + ) + self.assertEqual('gzip', json_data['uploads'][0]['content_encoding']) + + def test_get_asset_uploads_status_query(self): + response = self.client.get( + self.get_get_multipart_uploads_path(), {'status': CollectionAssetUpload.Status.ABORTED} + ) + self.assertStatusCode(200, response) + json_data = response.json() + self.assertIn('uploads', json_data) + self.assertGreater(len(json_data), 1) + self.assertEqual( + len(json_data['uploads']), + self.get_asset_upload_queryset().filter(status=CollectionAssetUpload.Status.ABORTED + ).count(), + ) + for upload in json_data['uploads']: + self.assertEqual(upload['status'], CollectionAssetUpload.Status.ABORTED) + + +class CollectionAssetUploadListPartsEndpointTestCase(CollectionAssetUploadBaseTest): + + def test_asset_upload_list_parts(self): + key = get_collection_asset_path(self.collection, 
self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 4 + size = 5 * MB * number_parts + file_like, checksum_multihash = get_file_like_object(size) + offset = size // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'file:checksum': checksum_multihash, + 'md5_parts': md5_parts + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + upload_id = json_data['upload_id'] + self.check_urls_response(json_data['urls'], number_parts) + + # List the uploaded parts should be empty + response = self.client.get(self.get_list_parts_path(upload_id)) + self.assertStatusCode(200, response) + json_data = response.json() + self.assertIn('links', json_data, msg='missing required field in list parts response') + self.assertIn('parts', json_data, msg='missing required field in list parts response') + self.assertEqual(len(json_data['parts']), 0, msg='parts should be empty') + + # upload all the parts + parts = self.s3_upload_parts(upload_id, file_like, size, number_parts) + + # List the uploaded parts should have 4 parts + response = self.client.get(self.get_list_parts_path(upload_id)) + self.assertStatusCode(200, response) + json_data = response.json() + self.assertIn('links', json_data, msg='missing required field in list parts response') + self.assertIn('parts', json_data, msg='missing required field in list parts response') + self.assertEqual(len(json_data['parts']), number_parts) + for part in json_data['parts']: + self.assertIn('etag', part) + self.assertIn('modified', part) + self.assertIn('size', part) + self.assertIn('part_number', part) + + # Unfortunately moto doesn't support yet the MaxParts + # (see https://github.com/spulec/moto/issues/2680) + # Test the list parts pagination + # response = self.client.get(self.get_list_parts_path(upload_id), {'limit': 2}) + # self.assertStatusCode(200, response) + # json_data = response.json() + # self.assertIn('links', json_data, msg='missing required field in list parts response') + # self.assertIn('parts', json_data, msg='missing required field in list parts response') + # self.assertEqual(len(json_data['parts']), number_parts) + # for part in json_data['parts']: + # self.assertIn('etag', part) + # self.assertIn('modified', part) + # self.assertIn('size', part) + # self.assertIn('part_number', part) + + # Complete the upload + response = self.client.post( + self.get_complete_multipart_upload_path(upload_id), + data={'parts': parts}, + content_type="application/json" + ) + self.assertStatusCode(200, response) + self.assertS3ObjectExists(key) diff --git a/app/tests/tests_10/test_collection_asset_upload_model.py b/app/tests/tests_10/test_collection_asset_upload_model.py new file mode 100644 index 00000000..9ea016aa --- /dev/null +++ b/app/tests/tests_10/test_collection_asset_upload_model.py @@ -0,0 +1,212 @@ +import logging +from datetime import datetime + +from django.core.exceptions import ValidationError +from django.db.models import ProtectedError +from django.test import TestCase +from django.test import TransactionTestCase + +from stac_api.models import CollectionAsset +from stac_api.models import CollectionAssetUpload +from stac_api.utils import get_sha256_multihash +from stac_api.utils import utc_aware + +from tests.tests_10.data_factory import Factory +from tests.utils import mock_s3_asset_file + +logger = logging.getLogger(__name__) + + 
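+# The tests below exercise the CollectionAssetUpload lifecycle: an upload starts as
+# in-progress, at most one in-progress upload may exist per asset (see the
+# unique_asset_upload_in_progress constraint), and both the upload itself and its
+# parent asset are protected from deletion while the upload is in progress (see the
+# pre_delete signals).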
+class CollectionAssetUploadTestCaseMixin:
+
+    def create_asset_upload(self, asset, upload_id, **kwargs):
+        asset_upload = CollectionAssetUpload(
+            asset=asset,
+            upload_id=upload_id,
+            checksum_multihash=get_sha256_multihash(b'Test'),
+            number_parts=1,
+            md5_parts=["this is an md5 value"],
+            **kwargs
+        )
+        asset_upload.full_clean()
+        asset_upload.save()
+        self.assertEqual(
+            asset_upload,
+            CollectionAssetUpload.objects.get(
+                upload_id=upload_id,
+                asset__name=asset.name,
+                asset__collection__name=asset.collection.name
+            )
+        )
+        return asset_upload
+
+    def update_asset_upload(self, asset_upload, **kwargs):
+        for kwarg, value in kwargs.items():
+            setattr(asset_upload, kwarg, value)
+        asset_upload.full_clean()
+        asset_upload.save()
+        asset_upload.refresh_from_db()
+        self.assertEqual(
+            asset_upload,
+            CollectionAssetUpload.objects.get(
+                upload_id=asset_upload.upload_id, asset__name=asset_upload.asset.name
+            )
+        )
+        return asset_upload
+
+    def check_etag(self, etag):
+        self.assertIsInstance(etag, str, msg="Etag must be a string")
+        self.assertNotEqual(etag, '', msg='Etag should not be empty')
+
+
+class CollectionAssetUploadModelTestCase(TestCase, CollectionAssetUploadTestCaseMixin):
+
+    @classmethod
+    @mock_s3_asset_file
+    def setUpTestData(cls):
+        cls.factory = Factory()
+        cls.collection = cls.factory.create_collection_sample().model
+        cls.asset_1 = cls.factory.create_collection_asset_sample(collection=cls.collection).model
+        cls.asset_2 = cls.factory.create_collection_asset_sample(collection=cls.collection).model
+
+    def test_create_asset_upload_default(self):
+        asset_upload = self.create_asset_upload(self.asset_1, 'default-upload')
+        self.assertEqual(asset_upload.urls, [], msg="Wrong default value")
+        self.assertEqual(asset_upload.ended, None, msg="Wrong default value")
+        self.assertAlmostEqual(
+            utc_aware(datetime.utcnow()).timestamp(),
+            asset_upload.created.timestamp(),  # pylint: disable=no-member
+            delta=1,
+            msg="Wrong default value"
+        )
+
+    def test_unique_constraint(self):
+        # Check that the asset upload is unique per collection/asset,
+        # therefore the following asset uploads should be ok
+        # collection-1/asset-1/default-upload
+        # collection-2/asset-1/default-upload
+        collection_2 = self.factory.create_collection_sample().model
+        asset_2 = self.factory.create_collection_asset_sample(
+            collection_2, name=self.asset_1.name
+        ).model
+        asset_upload_1 = self.create_asset_upload(self.asset_1, 'default-upload')
+        asset_upload_2 = self.create_asset_upload(asset_2, 'default-upload')
+        self.assertEqual(asset_upload_1.upload_id, asset_upload_2.upload_id)
+        self.assertEqual(asset_upload_1.asset.name, asset_upload_2.asset.name)
+        self.assertNotEqual(
+            asset_upload_1.asset.collection.name, asset_upload_2.asset.collection.name
+        )
+        # But duplicate paths are not allowed
+        with self.assertRaises(ValidationError, msg="Existing asset upload could be re-created."):
+            asset_upload_3 = self.create_asset_upload(self.asset_1, 'default-upload')
+
+    def test_create_asset_upload_duplicate_in_progress(self):
+        # create a first upload on asset 1
+        asset_upload_1 = self.create_asset_upload(self.asset_1, '1st-upload')
+
+        # create a first upload on asset 2
+        asset_upload_2 = self.create_asset_upload(self.asset_2, '1st-upload')
+
+        # creating a second upload on asset 1 should not be allowed.
+        with self.assertRaises(
+            ValidationError, msg="Existing asset upload already in progress could be re-created."
+ ): + asset_upload_3 = self.create_asset_upload(self.asset_1, '2nd-upload') + + def test_asset_upload_etag(self): + asset_upload = self.create_asset_upload(self.asset_1, 'default-upload') + original_etag = asset_upload.etag + self.check_etag(original_etag) + asset_upload = self.update_asset_upload( + asset_upload, status=CollectionAssetUpload.Status.ABORTED + ) + self.check_etag(asset_upload.etag) + self.assertNotEqual(asset_upload.etag, original_etag, msg='Etag was not updated') + + def test_asset_upload_invalid_number_parts(self): + with self.assertRaises(ValidationError): + asset_upload = CollectionAssetUpload( + asset=self.asset_1, + upload_id='my-upload-id', + checksum_multihash=get_sha256_multihash(b'Test'), + number_parts=-1, + md5_parts=['fake_md5'] + ) + asset_upload.full_clean() + asset_upload.save() + + +class CollectionAssetUploadDeleteProtectModelTestCase( + TransactionTestCase, CollectionAssetUploadTestCaseMixin +): + + @mock_s3_asset_file + def setUp(self): + self.factory = Factory() + self.collection = self.factory.create_collection_sample().model + self.asset = self.factory.create_collection_asset_sample(collection=self.collection).model + + def test_delete_asset_upload(self): + upload_id = 'upload-in-progress' + asset_upload = self.create_asset_upload(self.asset, upload_id) + + with self.assertRaises(ProtectedError, msg="Deleting an upload in progress not allowed"): + asset_upload.delete() + + asset_upload = self.update_asset_upload( + asset_upload, + status=CollectionAssetUpload.Status.COMPLETED, + ended=utc_aware(datetime.utcnow()) + ) + + asset_upload.delete() + self.assertFalse( + CollectionAssetUpload.objects.all().filter( + upload_id=upload_id, asset__name=self.asset.name + ).exists() + ) + + def test_delete_asset_with_upload_in_progress(self): + asset_upload_1 = self.create_asset_upload(self.asset, 'upload-in-progress') + asset_upload_2 = self.create_asset_upload( + self.asset, + 'upload-completed', + status=CollectionAssetUpload.Status.COMPLETED, + ended=utc_aware(datetime.utcnow()) + ) + asset_upload_3 = self.create_asset_upload( + self.asset, + 'upload-aborted', + status=CollectionAssetUpload.Status.ABORTED, + ended=utc_aware(datetime.utcnow()) + ) + asset_upload_4 = self.create_asset_upload( + self.asset, + 'upload-aborted-2', + status=CollectionAssetUpload.Status.ABORTED, + ended=utc_aware(datetime.utcnow()) + ) + + # Try to delete parent asset + with self.assertRaises(ValidationError): + self.asset.delete() + self.assertEqual(4, len(list(CollectionAssetUpload.objects.all()))) + self.assertTrue( + CollectionAsset.objects.all().filter( + name=self.asset.name, collection__name=self.collection.name + ).exists() + ) + + self.update_asset_upload( + asset_upload_1, + status=CollectionAssetUpload.Status.ABORTED, + ended=utc_aware(datetime.utcnow()) + ) + + self.asset.delete() + self.assertEqual(0, len(list(CollectionAssetUpload.objects.all()))) + self.assertFalse( + CollectionAsset.objects.all().filter( + name=self.asset.name, collection__name=self.collection.name + ).exists() + ) From 77f042bb6e5e43c263ce529f436f776753e613fe Mon Sep 17 00:00:00 2001 From: Benjamin Sugden Date: Mon, 9 Sep 2024 17:19:49 +0200 Subject: [PATCH 03/16] PB-756: Refactor asset upload views Use shared base class for asset upload and collection asset upload. 
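In outline, the refactor applies the template-method pattern: the shared base
implements the S3 multipart workflow once and defers only the model lookup to
its subclasses. The following is a minimal standalone sketch of that shape
(plain Python with hypothetical stub classes, not the actual Django code):

# Sketch of the pattern used by SharedAssetUploadBase below. The stub classes
# stand in for the real models; only the structure mirrors the diff.
class CollectionAssetStub:

    def __init__(self, collection, name):
        self.collection = collection
        self.name = name


class ItemAssetStub:

    def __init__(self, collection, item, name):
        self.collection = collection
        self.item = item
        self.name = name


class SharedUploadBase:

    def get_queryset(self):
        # hook overridden by each concrete view (raises, like the real base)
        raise NotImplementedError("get_queryset() not implemented")

    def get_path(self, asset):
        # shared isinstance() dispatch, as in SharedAssetUploadBase.get_path()
        if isinstance(asset, CollectionAssetStub):
            return f"{asset.collection}/{asset.name}"
        return f"{asset.collection}/{asset.item}/{asset.name}"

    def abort_multipart_upload(self, asset):
        # every S3 call resolves its key via get_path(), so the workflow
        # (create/complete/abort/list parts) is implemented exactly once
        return f"abort multipart upload on {self.get_path(asset)}"


print(SharedUploadBase().abort_multipart_upload(CollectionAssetStub("col1", "asset-1.tif")))
# prints: abort multipart upload on col1/asset-1.tif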
--- app/stac_api/views.py | 165 +++++++++++++----------------------------- 1 file changed, 51 insertions(+), 114 deletions(-) diff --git a/app/stac_api/views.py b/app/stac_api/views.py index 275439e3..f8ea2350 100644 --- a/app/stac_api/views.py +++ b/app/stac_api/views.py @@ -29,6 +29,7 @@ from stac_api.exceptions import UploadNotInProgressError from stac_api.models import Asset from stac_api.models import AssetUpload +from stac_api.models import BaseAssetUpload from stac_api.models import Collection from stac_api.models import CollectionAsset from stac_api.models import CollectionAssetUpload @@ -797,28 +798,32 @@ def delete(self, request, *args, **kwargs): return self.destroy(request, *args, **kwargs) -class AssetUploadBase(generics.GenericAPIView): - serializer_class = AssetUploadSerializer +class SharedAssetUploadBase(generics.GenericAPIView): + """SharedAssetUploadBase provides a base view for asset uploads and collection asset uploads. + """ lookup_url_kwarg = "upload_id" lookup_field = "upload_id" def get_queryset(self): - return AssetUpload.objects.filter( - asset__item__collection__name=self.kwargs['collection_name'], - asset__item__name=self.kwargs['item_name'], - asset__name=self.kwargs['asset_name'] - ).prefetch_related('asset') + raise NotImplementedError("get_queryset() not implemented") def get_in_progress_queryset(self): - return self.get_queryset().filter(status=AssetUpload.Status.IN_PROGRESS) + return self.get_queryset().filter(status=BaseAssetUpload.Status.IN_PROGRESS) def get_asset_or_404(self): - return get_object_or_404( - Asset.objects.all(), - name=self.kwargs['asset_name'], - item__name=self.kwargs['item_name'], - item__collection__name=self.kwargs['collection_name'] - ) + raise NotImplementedError("get_asset_or_404() not implemented") + + def log_extra(self, asset): + if isinstance(asset, CollectionAsset): + return {'collection': asset.collection.name, 'asset': asset.name} + return { + 'collection': asset.item.collection.name, 'item': asset.item.name, 'asset': asset.name + } + + def get_path(self, asset): + if isinstance(asset, CollectionAsset): + return get_collection_asset_path(asset.collection, asset.name) + return get_asset_path(asset.item, asset.name) def _save_asset_upload(self, executor, serializer, key, asset, upload_id, urls): try: @@ -826,13 +831,7 @@ def _save_asset_upload(self, executor, serializer, key, asset, upload_id, urls): serializer.save(asset=asset, upload_id=upload_id, urls=urls) except IntegrityError as error: logger.error( - 'Failed to create asset upload multipart: %s', - error, - extra={ - 'collection': asset.item.collection.name, - 'item': asset.item.name, - 'asset': asset.name - } + 'Failed to create asset upload multipart: %s', error, extra=self.log_extra(asset) ) if bool(self.get_in_progress_queryset()): raise UploadInProgressError( @@ -841,7 +840,7 @@ def _save_asset_upload(self, executor, serializer, key, asset, upload_id, urls): raise def create_multipart_upload(self, executor, serializer, validated_data, asset): - key = get_asset_path(asset.item, asset.name) + key = self.get_path(asset) upload_id = executor.create_multipart_upload( key, @@ -867,7 +866,7 @@ def create_multipart_upload(self, executor, serializer, validated_data, asset): raise def complete_multipart_upload(self, executor, validated_data, asset_upload, asset): - key = get_asset_path(asset.item, asset.name) + key = self.get_path(asset) parts = validated_data.get('parts', None) if parts is None: raise serializers.ValidationError({ @@ -877,28 +876,49 @@ def 
complete_multipart_upload(self, executor, validated_data, asset_upload, asse raise serializers.ValidationError({'parts': [_("Too many parts")]}, code='invalid') if len(parts) < asset_upload.number_parts: raise serializers.ValidationError({'parts': [_("Too few parts")]}, code='invalid') - if asset_upload.status != AssetUpload.Status.IN_PROGRESS: + if asset_upload.status != BaseAssetUpload.Status.IN_PROGRESS: raise UploadNotInProgressError() executor.complete_multipart_upload(key, asset, parts, asset_upload.upload_id) asset_upload.update_asset_from_upload() - asset_upload.status = AssetUpload.Status.COMPLETED + asset_upload.status = BaseAssetUpload.Status.COMPLETED asset_upload.ended = utc_aware(datetime.utcnow()) asset_upload.urls = [] asset_upload.save() def abort_multipart_upload(self, executor, asset_upload, asset): - key = get_asset_path(asset.item, asset.name) + key = self.get_path(asset) executor.abort_multipart_upload(key, asset, asset_upload.upload_id) - asset_upload.status = AssetUpload.Status.ABORTED + asset_upload.status = BaseAssetUpload.Status.ABORTED asset_upload.ended = utc_aware(datetime.utcnow()) asset_upload.urls = [] asset_upload.save() def list_multipart_upload_parts(self, executor, asset_upload, asset, limit, offset): - key = get_asset_path(asset.item, asset.name) + key = self.get_path(asset) return executor.list_upload_parts(key, asset, asset_upload.upload_id, limit, offset) +class AssetUploadBase(SharedAssetUploadBase): + """AssetUploadBase is the base for all asset (not collection asset) upload views. + """ + serializer_class = AssetUploadSerializer + + def get_queryset(self): + return AssetUpload.objects.filter( + asset__item__collection__name=self.kwargs['collection_name'], + asset__item__name=self.kwargs['item_name'], + asset__name=self.kwargs['asset_name'] + ).prefetch_related('asset') + + def get_asset_or_404(self): + return get_object_or_404( + Asset.objects.all(), + name=self.kwargs['asset_name'], + item__name=self.kwargs['item_name'], + item__collection__name=self.kwargs['collection_name'] + ) + + class AssetUploadsList(AssetUploadBase, mixins.ListModelMixin, views_mixins.CreateModelMixin): class ExternalDisallowedException(Exception): @@ -950,10 +970,6 @@ class AssetUploadDetail(AssetUploadBase, mixins.RetrieveModelMixin, views_mixins def get(self, request, *args, **kwargs): return self.retrieve(request, *args, **kwargs) - # @etag(get_asset_upload_etag) - # def delete(self, request, *args, **kwargs): - # return self.destroy(request, *args, **kwargs) - class AssetUploadComplete(AssetUploadBase, views_mixins.UpdateInsertModelMixin): @@ -1020,10 +1036,10 @@ def get_paginated_response(self, data, has_next): # pylint: disable=arguments-d return self.paginator.get_paginated_response(data, has_next) -class CollectionAssetUploadBase(generics.GenericAPIView): +class CollectionAssetUploadBase(SharedAssetUploadBase): + """CollectionAssetUploadBase is the base for all collection asset upload views. 
+ """ serializer_class = CollectionAssetUploadSerializer - lookup_url_kwarg = "upload_id" - lookup_field = "upload_id" def get_queryset(self): return CollectionAssetUpload.objects.filter( @@ -1031,9 +1047,6 @@ def get_queryset(self): asset__name=self.kwargs['asset_name'] ).prefetch_related('asset') - def get_in_progress_queryset(self): - return self.get_queryset().filter(status=CollectionAssetUpload.Status.IN_PROGRESS) - def get_asset_or_404(self): return get_object_or_404( CollectionAsset.objects.all(), @@ -1041,82 +1054,6 @@ def get_asset_or_404(self): collection__name=self.kwargs['collection_name'] ) - def _save_asset_upload(self, executor, serializer, key, asset, upload_id, urls): - try: - with transaction.atomic(): - serializer.save(asset=asset, upload_id=upload_id, urls=urls) - except IntegrityError as error: - logger.error( - 'Failed to create asset upload multipart: %s', - error, - extra={ - 'collection': asset.collection.name, 'asset': asset.name - } - ) - if bool(self.get_in_progress_queryset()): - raise UploadInProgressError( - data={"upload_id": self.get_in_progress_queryset()[0].upload_id} - ) from None - raise - - def create_multipart_upload(self, executor, serializer, validated_data, asset): - key = get_collection_asset_path(asset.collection, asset.name) - - upload_id = executor.create_multipart_upload( - key, - asset, - validated_data['checksum_multihash'], - validated_data['update_interval'], - validated_data['content_encoding'] - ) - urls = [] - sorted_md5_parts = sorted(validated_data['md5_parts'], key=itemgetter('part_number')) - - try: - for part in sorted_md5_parts: - urls.append( - executor.create_presigned_url( - key, asset, part['part_number'], upload_id, part['md5'] - ) - ) - - self._save_asset_upload(executor, serializer, key, asset, upload_id, urls) - except APIException as err: - executor.abort_multipart_upload(key, asset, upload_id) - raise - - def complete_multipart_upload(self, executor, validated_data, asset_upload, asset): - key = get_collection_asset_path(asset.collection, asset.name) - parts = validated_data.get('parts', None) - if parts is None: - raise serializers.ValidationError({ - 'parts': _("Missing required field") - }, code='missing') - if len(parts) > asset_upload.number_parts: - raise serializers.ValidationError({'parts': [_("Too many parts")]}, code='invalid') - if len(parts) < asset_upload.number_parts: - raise serializers.ValidationError({'parts': [_("Too few parts")]}, code='invalid') - if asset_upload.status != CollectionAssetUpload.Status.IN_PROGRESS: - raise UploadNotInProgressError() - executor.complete_multipart_upload(key, asset, parts, asset_upload.upload_id) - asset_upload.update_asset_from_upload() - asset_upload.status = CollectionAssetUpload.Status.COMPLETED - asset_upload.ended = utc_aware(datetime.utcnow()) - asset_upload.urls = [] - asset_upload.save() - - def abort_multipart_upload(self, executor, asset_upload, asset): - key = get_collection_asset_path(asset.collection, asset.name) - executor.abort_multipart_upload(key, asset, asset_upload.upload_id) - asset_upload.status = CollectionAssetUpload.Status.ABORTED - asset_upload.ended = utc_aware(datetime.utcnow()) - asset_upload.urls = [] - asset_upload.save() - - def list_multipart_upload_parts(self, executor, asset_upload, asset, limit, offset): - key = get_collection_asset_path(asset.collection, asset.name) - return executor.list_upload_parts(key, asset, asset_upload.upload_id, limit, offset) - class CollectionAssetUploadsList( CollectionAssetUploadBase, 
mixins.ListModelMixin, views_mixins.CreateModelMixin From 9105c825d2d6afd1af0d7f3c618ae0a0b8e8107e Mon Sep 17 00:00:00 2001 From: Benjamin Sugden Date: Tue, 10 Sep 2024 15:44:58 +0200 Subject: [PATCH 04/16] PB-756: Fix comments for collection assets Add type hint for new function log_extra. --- .../0050_collectionassetupload_and_more.py | 14 ++++++++++++-- app/stac_api/models.py | 10 +++++----- app/stac_api/s3_multipart_upload.py | 6 ++++-- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/app/stac_api/migrations/0050_collectionassetupload_and_more.py b/app/stac_api/migrations/0050_collectionassetupload_and_more.py index c933b2db..1238a62a 100644 --- a/app/stac_api/migrations/0050_collectionassetupload_and_more.py +++ b/app/stac_api/migrations/0050_collectionassetupload_and_more.py @@ -1,4 +1,4 @@ -# Generated by Django 5.0.8 on 2024-09-09 10:59 +# Generated by Django 5.0.8 on 2024-09-10 12:45 import pgtrigger.compiler import pgtrigger.migrations @@ -19,6 +19,16 @@ class Migration(migrations.Migration): ] operations = [ + migrations.AlterField( + model_name='assetupload', + name='update_interval', + field=models.IntegerField( + default=-1, + help_text= + 'Interval in seconds in which the asset data is updated. -1 means that the data is not on a regular basis updated. This field can only be set via the API.', + validators=[django.core.validators.MinValueValidator(-1)] + ), + ), migrations.CreateModel( name='CollectionAssetUpload', fields=[ @@ -65,7 +75,7 @@ class Migration(migrations.Migration): models.IntegerField( default=-1, help_text= - 'Interval in seconds in which the asset data is updated.-1 means that the data is not on a regular basis updated.This field can only be set via the API.', + 'Interval in seconds in which the asset data is updated. -1 means that the data is not on a regular basis updated. This field can only be set via the API.', validators=[django.core.validators.MinValueValidator(-1)] ) ), diff --git a/app/stac_api/models.py b/app/stac_api/models.py index 88237b4a..c4ef6de3 100644 --- a/app/stac_api/models.py +++ b/app/stac_api/models.py @@ -755,8 +755,8 @@ class ContentEncoding(models.TextChoices): null=False, blank=False, validators=[MinValueValidator(-1)], - help_text="Interval in seconds in which the asset data is updated." - "-1 means that the data is not on a regular basis updated." + help_text="Interval in seconds in which the asset data is updated. " + "-1 means that the data is not on a regular basis updated. " "This field can only be set via the API." ) @@ -819,7 +819,7 @@ class Meta: fields=['asset', 'upload_id'], name='unique_asset_upload_collection_asset_upload_id' ), - # Make sure that there is only one asset upload in progress per asset + # Make sure that there is only one upload in progress per collection asset models.UniqueConstraint( fields=['asset', 'status'], condition=Q(status='in-progress'), @@ -837,8 +837,8 @@ def update_asset_from_upload(self): is set to its asset parent. 
''' logger.debug( - 'Updating asset %s file:checksum from %s to %s and update_interval from %d to %d ' - 'due to upload complete', + 'Updating collection asset %s file:checksum from %s to %s and update_interval ' + 'from %d to %d due to upload complete', self.asset.name, self.asset.checksum_multihash, self.checksum_multihash, diff --git a/app/stac_api/s3_multipart_upload.py b/app/stac_api/s3_multipart_upload.py index d17f5eec..0c5a10c8 100644 --- a/app/stac_api/s3_multipart_upload.py +++ b/app/stac_api/s3_multipart_upload.py @@ -12,6 +12,8 @@ from rest_framework import serializers from stac_api.exceptions import UploadNotInProgressError +from stac_api.models import Asset +from stac_api.models import CollectionAsset from stac_api.utils import AVAILABLE_S3_BUCKETS from stac_api.utils import get_s3_cache_control_value from stac_api.utils import get_s3_client @@ -65,8 +67,8 @@ def list_multipart_uploads(self, key=None, limit=100, start=None): response.get('NextUploadIdMarker', None), ) - def log_extra(self, asset, upload_id=None, parts=None): - if hasattr(asset, 'item'): + def log_extra(self, asset: Asset | CollectionAsset, upload_id=None, parts=None): + if isinstance(asset, Asset): log_extra = { 'collection': asset.item.collection.name, 'item': asset.item.name, From eb93e717228e03683d953a35a45e426707bb7108 Mon Sep 17 00:00:00 2001 From: Benjamin Sugden Date: Tue, 10 Sep 2024 17:11:02 +0200 Subject: [PATCH 05/16] Move views files to subfolder --- app/config/urls.py | 10 +++--- app/stac_api/urls.py | 44 ++++++++++++------------ app/stac_api/views/__init__.py | 0 app/stac_api/{ => views}/views.py | 2 +- app/stac_api/{ => views}/views_mixins.py | 0 app/stac_api/{ => views}/views_test.py | 8 ++--- 6 files changed, 32 insertions(+), 32 deletions(-) create mode 100644 app/stac_api/views/__init__.py rename app/stac_api/{ => views}/views.py (99%) rename app/stac_api/{ => views}/views_mixins.py (100%) rename app/stac_api/{ => views}/views_test.py (85%) diff --git a/app/config/urls.py b/app/config/urls.py index d28f7c8b..0f481c90 100644 --- a/app/config/urls.py +++ b/app/config/urls.py @@ -38,11 +38,11 @@ def checker(request): if settings.DEBUG: import debug_toolbar - from stac_api.views_test import TestAssetUpsertHttp500 - from stac_api.views_test import TestCollectionAssetUpsertHttp500 - from stac_api.views_test import TestCollectionUpsertHttp500 - from stac_api.views_test import TestHttp500 - from stac_api.views_test import TestItemUpsertHttp500 + from stac_api.views.views_test import TestAssetUpsertHttp500 + from stac_api.views.views_test import TestCollectionAssetUpsertHttp500 + from stac_api.views.views_test import TestCollectionUpsertHttp500 + from stac_api.views.views_test import TestHttp500 + from stac_api.views.views_test import TestItemUpsertHttp500 urlpatterns = [ path('__debug__/', include(debug_toolbar.urls)), diff --git a/app/stac_api/urls.py b/app/stac_api/urls.py index 4f5e2ccd..7114a35a 100644 --- a/app/stac_api/urls.py +++ b/app/stac_api/urls.py @@ -4,28 +4,28 @@ from rest_framework.authtoken.views import obtain_auth_token -from stac_api.views import AssetDetail -from stac_api.views import AssetsList -from stac_api.views import AssetUploadAbort -from stac_api.views import AssetUploadComplete -from stac_api.views import AssetUploadDetail -from stac_api.views import AssetUploadPartsList -from stac_api.views import AssetUploadsList -from stac_api.views import CollectionAssetDetail -from stac_api.views import CollectionAssetsList -from stac_api.views import CollectionAssetUploadAbort 
-from stac_api.views import CollectionAssetUploadComplete -from stac_api.views import CollectionAssetUploadDetail -from stac_api.views import CollectionAssetUploadPartsList -from stac_api.views import CollectionAssetUploadsList -from stac_api.views import CollectionDetail -from stac_api.views import CollectionList -from stac_api.views import ConformancePageDetail -from stac_api.views import ItemDetail -from stac_api.views import ItemsList -from stac_api.views import LandingPageDetail -from stac_api.views import SearchList -from stac_api.views import recalculate_extent +from stac_api.views.views import AssetDetail +from stac_api.views.views import AssetsList +from stac_api.views.views import AssetUploadAbort +from stac_api.views.views import AssetUploadComplete +from stac_api.views.views import AssetUploadDetail +from stac_api.views.views import AssetUploadPartsList +from stac_api.views.views import AssetUploadsList +from stac_api.views.views import CollectionAssetDetail +from stac_api.views.views import CollectionAssetsList +from stac_api.views.views import CollectionAssetUploadAbort +from stac_api.views.views import CollectionAssetUploadComplete +from stac_api.views.views import CollectionAssetUploadDetail +from stac_api.views.views import CollectionAssetUploadPartsList +from stac_api.views.views import CollectionAssetUploadsList +from stac_api.views.views import CollectionDetail +from stac_api.views.views import CollectionList +from stac_api.views.views import ConformancePageDetail +from stac_api.views.views import ItemDetail +from stac_api.views.views import ItemsList +from stac_api.views.views import LandingPageDetail +from stac_api.views.views import SearchList +from stac_api.views.views import recalculate_extent # HEALTHCHECK_ENDPOINT = settings.HEALTHCHECK_ENDPOINT diff --git a/app/stac_api/views/__init__.py b/app/stac_api/views/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/stac_api/views.py b/app/stac_api/views/views.py similarity index 99% rename from app/stac_api/views.py rename to app/stac_api/views/views.py index f8ea2350..7c1ccd57 100644 --- a/app/stac_api/views.py +++ b/app/stac_api/views/views.py @@ -24,7 +24,6 @@ from rest_framework.response import Response from rest_framework_condition import etag -from stac_api import views_mixins from stac_api.exceptions import UploadInProgressError from stac_api.exceptions import UploadNotInProgressError from stac_api.models import Asset @@ -61,6 +60,7 @@ from stac_api.validators_view import validate_collection_asset from stac_api.validators_view import validate_item from stac_api.validators_view import validate_renaming +from stac_api.views import views_mixins logger = logging.getLogger(__name__) diff --git a/app/stac_api/views_mixins.py b/app/stac_api/views/views_mixins.py similarity index 100% rename from app/stac_api/views_mixins.py rename to app/stac_api/views/views_mixins.py diff --git a/app/stac_api/views_test.py b/app/stac_api/views/views_test.py similarity index 85% rename from app/stac_api/views_test.py rename to app/stac_api/views/views_test.py index d1beefd2..6cfe4201 100644 --- a/app/stac_api/views_test.py +++ b/app/stac_api/views/views_test.py @@ -3,10 +3,10 @@ from rest_framework import generics from stac_api.models import LandingPage -from stac_api.views import AssetDetail -from stac_api.views import CollectionAssetDetail -from stac_api.views import CollectionDetail -from stac_api.views import ItemDetail +from stac_api.views.views import AssetDetail +from stac_api.views.views import 
CollectionAssetDetail +from stac_api.views.views import CollectionDetail +from stac_api.views.views import ItemDetail logger = logging.getLogger(__name__) From b5f0883457eef8492e17991a4cf683a3912fd5ae Mon Sep 17 00:00:00 2001 From: Benjamin Sugden Date: Tue, 10 Sep 2024 17:16:26 +0200 Subject: [PATCH 06/16] Move collection views to own file --- app/stac_api/urls.py | 4 +- app/stac_api/views/collection.py | 119 +++++++++++++++++++++++++++++++ app/stac_api/views/views.py | 103 -------------------------- app/stac_api/views/views_test.py | 2 +- 4 files changed, 122 insertions(+), 106 deletions(-) create mode 100644 app/stac_api/views/collection.py diff --git a/app/stac_api/urls.py b/app/stac_api/urls.py index 7114a35a..c9011901 100644 --- a/app/stac_api/urls.py +++ b/app/stac_api/urls.py @@ -4,6 +4,8 @@ from rest_framework.authtoken.views import obtain_auth_token +from stac_api.views.collection import CollectionDetail +from stac_api.views.collection import CollectionList from stac_api.views.views import AssetDetail from stac_api.views.views import AssetsList from stac_api.views.views import AssetUploadAbort @@ -18,8 +20,6 @@ from stac_api.views.views import CollectionAssetUploadDetail from stac_api.views.views import CollectionAssetUploadPartsList from stac_api.views.views import CollectionAssetUploadsList -from stac_api.views.views import CollectionDetail -from stac_api.views.views import CollectionList from stac_api.views.views import ConformancePageDetail from stac_api.views.views import ItemDetail from stac_api.views.views import ItemsList diff --git a/app/stac_api/views/collection.py b/app/stac_api/views/collection.py new file mode 100644 index 00000000..366e9c24 --- /dev/null +++ b/app/stac_api/views/collection.py @@ -0,0 +1,119 @@ +import logging + +from django.conf import settings + +from rest_framework import generics +from rest_framework import mixins +from rest_framework.response import Response +from rest_framework_condition import etag + +from stac_api.models import Collection +from stac_api.serializers import CollectionSerializer +from stac_api.serializers_utils import get_relation_links +from stac_api.validators_view import validate_renaming +from stac_api.views import views_mixins +from stac_api.views.views import get_etag + +logger = logging.getLogger(__name__) + + +def get_collection_etag(request, *args, **kwargs): + '''Get the ETag for a collection object + + The ETag is an UUID4 computed on each object changes (including relations; provider and links) + ''' + tag = get_etag(Collection.objects.filter(name=kwargs['collection_name'])) + + if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE: + logger.debug( + "Output of EXPLAIN.. ANALYZE from get_collection_etag():\n%s", + Collection.objects.filter(name=kwargs['collection_name'] + ).explain(verbose=True, analyze=True) + ) + logger.debug( + "The corresponding SQL statement:\n%s", + Collection.objects.filter(name=kwargs['collection_name']).query + ) + + return tag + + +class CollectionList(generics.GenericAPIView): + name = 'collections-list' # this name must match the name in urls.py + serializer_class = CollectionSerializer + # prefetch_related is a performance optimization to reduce the number + # of DB queries. 
+ # see https://docs.djangoproject.com/en/3.1/ref/models/querysets/#prefetch-related + queryset = Collection.objects.filter(published=True).prefetch_related('providers', 'links') + ordering = ['name'] + + def get(self, request, *args, **kwargs): + queryset = self.filter_queryset(self.get_queryset()) + page = self.paginate_queryset(queryset) + if page is not None: + serializer = self.get_serializer(page, many=True) + else: + serializer = self.get_serializer(queryset, many=True) + + data = {'collections': serializer.data, 'links': get_relation_links(request, self.name)} + + if page is not None: + return self.get_paginated_response(data) + return Response(data) + + +class CollectionDetail( + generics.GenericAPIView, + mixins.RetrieveModelMixin, + views_mixins.UpdateInsertModelMixin, + views_mixins.DestroyModelMixin +): + # this name must match the name in urls.py and is used by the DestroyModelMixin + name = 'collection-detail' + serializer_class = CollectionSerializer + lookup_url_kwarg = "collection_name" + lookup_field = "name" + queryset = Collection.objects.all().prefetch_related('providers', 'links') + + @etag(get_collection_etag) + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_collection_etag) + def put(self, request, *args, **kwargs): + return self.upsert(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_collection_etag) + def patch(self, request, *args, **kwargs): + return self.partial_update(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_collection_etag) + def delete(self, request, *args, **kwargs): + return self.destroy(request, *args, **kwargs) + + def perform_upsert(self, serializer, lookup): + validate_renaming( + serializer, + self.kwargs['collection_name'], + extra_log={ + # pylint: disable=protected-access + 'request': self.request._request, + 'collection': self.kwargs['collection_name'] + } + ) + return super().perform_upsert(serializer, lookup) + + def perform_update(self, serializer, *args, **kwargs): + validate_renaming( + serializer, + self.kwargs['collection_name'], + extra_log={ + # pylint: disable=protected-access + 'request': self.request._request, + 'collection': self.kwargs['collection_name'] + } + ) + return super().perform_update(serializer, *args, **kwargs) diff --git a/app/stac_api/views/views.py b/app/stac_api/views/views.py index 7c1ccd57..43b08719 100644 --- a/app/stac_api/views/views.py +++ b/app/stac_api/views/views.py @@ -42,7 +42,6 @@ from stac_api.serializers import AssetUploadSerializer from stac_api.serializers import CollectionAssetSerializer from stac_api.serializers import CollectionAssetUploadSerializer -from stac_api.serializers import CollectionSerializer from stac_api.serializers import ConformancePageSerializer from stac_api.serializers import ItemSerializer from stac_api.serializers import LandingPageSerializer @@ -71,27 +70,6 @@ def get_etag(queryset): return None -def get_collection_etag(request, *args, **kwargs): - '''Get the ETag for a collection object - - The ETag is an UUID4 computed on each object changes (including relations; provider and links) - ''' - tag = get_etag(Collection.objects.filter(name=kwargs['collection_name'])) - - if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE: - logger.debug( - "Output of EXPLAIN.. 
ANALYZE from get_collection_etag():\n%s", - Collection.objects.filter(name=kwargs['collection_name'] - ).explain(verbose=True, analyze=True) - ) - logger.debug( - "The corresponding SQL statement:\n%s", - Collection.objects.filter(name=kwargs['collection_name']).query - ) - - return tag - - def get_item_etag(request, *args, **kwargs): '''Get the ETag for a item object @@ -316,30 +294,6 @@ def post(self, request, *args, **kwargs): return response -class CollectionList(generics.GenericAPIView): - name = 'collections-list' # this name must match the name in urls.py - serializer_class = CollectionSerializer - # prefetch_related is a performance optimization to reduce the number - # of DB queries. - # see https://docs.djangoproject.com/en/3.1/ref/models/querysets/#prefetch-related - queryset = Collection.objects.filter(published=True).prefetch_related('providers', 'links') - ordering = ['name'] - - def get(self, request, *args, **kwargs): - queryset = self.filter_queryset(self.get_queryset()) - page = self.paginate_queryset(queryset) - if page is not None: - serializer = self.get_serializer(page, many=True) - else: - serializer = self.get_serializer(queryset, many=True) - - data = {'collections': serializer.data, 'links': get_relation_links(request, self.name)} - - if page is not None: - return self.get_paginated_response(data) - return Response(data) - - @api_view(['POST']) @permission_classes((permissions.AllowAny,)) def recalculate_extent(request): @@ -347,63 +301,6 @@ def recalculate_extent(request): return Response() -class CollectionDetail( - generics.GenericAPIView, - mixins.RetrieveModelMixin, - views_mixins.UpdateInsertModelMixin, - views_mixins.DestroyModelMixin -): - # this name must match the name in urls.py and is used by the DestroyModelMixin - name = 'collection-detail' - serializer_class = CollectionSerializer - lookup_url_kwarg = "collection_name" - lookup_field = "name" - queryset = Collection.objects.all().prefetch_related('providers', 'links') - - @etag(get_collection_etag) - def get(self, request, *args, **kwargs): - return self.retrieve(request, *args, **kwargs) - - # Here the etag is only added to support pre-conditional If-Match and If-Not-Match - @etag(get_collection_etag) - def put(self, request, *args, **kwargs): - return self.upsert(request, *args, **kwargs) - - # Here the etag is only added to support pre-conditional If-Match and If-Not-Match - @etag(get_collection_etag) - def patch(self, request, *args, **kwargs): - return self.partial_update(request, *args, **kwargs) - - # Here the etag is only added to support pre-conditional If-Match and If-Not-Match - @etag(get_collection_etag) - def delete(self, request, *args, **kwargs): - return self.destroy(request, *args, **kwargs) - - def perform_upsert(self, serializer, lookup): - validate_renaming( - serializer, - self.kwargs['collection_name'], - extra_log={ - # pylint: disable=protected-access - 'request': self.request._request, - 'collection': self.kwargs['collection_name'] - } - ) - return super().perform_upsert(serializer, lookup) - - def perform_update(self, serializer, *args, **kwargs): - validate_renaming( - serializer, - self.kwargs['collection_name'], - extra_log={ - # pylint: disable=protected-access - 'request': self.request._request, - 'collection': self.kwargs['collection_name'] - } - ) - return super().perform_update(serializer, *args, **kwargs) - - class ItemsList(generics.GenericAPIView): serializer_class = ItemSerializer ordering = ['name'] diff --git a/app/stac_api/views/views_test.py 
b/app/stac_api/views/views_test.py index 6cfe4201..7d8c89a3 100644 --- a/app/stac_api/views/views_test.py +++ b/app/stac_api/views/views_test.py @@ -3,9 +3,9 @@ from rest_framework import generics from stac_api.models import LandingPage +from stac_api.views.collection import CollectionDetail from stac_api.views.views import AssetDetail from stac_api.views.views import CollectionAssetDetail -from stac_api.views.views import CollectionDetail from stac_api.views.views import ItemDetail logger = logging.getLogger(__name__) From c36300f9f3c87e7c7004e3dc02d245e7a9d73f78 Mon Sep 17 00:00:00 2001 From: Benjamin Sugden Date: Tue, 10 Sep 2024 17:24:42 +0200 Subject: [PATCH 07/16] Move item views to own file --- app/stac_api/urls.py | 4 +- app/stac_api/views/item.py | 186 +++++++++++++++++++++++++++++++ app/stac_api/views/views.py | 162 --------------------------- app/stac_api/views/views_test.py | 2 +- 4 files changed, 189 insertions(+), 165 deletions(-) create mode 100644 app/stac_api/views/item.py diff --git a/app/stac_api/urls.py b/app/stac_api/urls.py index c9011901..24e9b0a0 100644 --- a/app/stac_api/urls.py +++ b/app/stac_api/urls.py @@ -6,6 +6,8 @@ from stac_api.views.collection import CollectionDetail from stac_api.views.collection import CollectionList +from stac_api.views.item import ItemDetail +from stac_api.views.item import ItemsList from stac_api.views.views import AssetDetail from stac_api.views.views import AssetsList from stac_api.views.views import AssetUploadAbort @@ -21,8 +23,6 @@ from stac_api.views.views import CollectionAssetUploadPartsList from stac_api.views.views import CollectionAssetUploadsList from stac_api.views.views import ConformancePageDetail -from stac_api.views.views import ItemDetail -from stac_api.views.views import ItemsList from stac_api.views.views import LandingPageDetail from stac_api.views.views import SearchList from stac_api.views.views import recalculate_extent diff --git a/app/stac_api/views/item.py b/app/stac_api/views/item.py new file mode 100644 index 00000000..a9113ab6 --- /dev/null +++ b/app/stac_api/views/item.py @@ -0,0 +1,186 @@ +import logging +from datetime import datetime + +from django.conf import settings +from django.db.models import Prefetch +from django.db.models import Q +from django.utils import timezone + +from rest_framework import generics +from rest_framework.generics import get_object_or_404 +from rest_framework.response import Response +from rest_framework_condition import etag + +from stac_api.models import Asset +from stac_api.models import Collection +from stac_api.models import Item +from stac_api.serializers import ItemSerializer +from stac_api.serializers_utils import get_relation_links +from stac_api.utils import utc_aware +from stac_api.validators_view import validate_collection +from stac_api.validators_view import validate_renaming +from stac_api.views import views_mixins +from stac_api.views.views import get_etag + +logger = logging.getLogger(__name__) + + +def get_item_etag(request, *args, **kwargs): + '''Get the ETag for a item object + + The ETag is an UUID4 computed on each object changes (including relations; assets and links) + ''' + tag = get_etag( + Item.objects.filter(collection__name=kwargs['collection_name'], name=kwargs['item_name']) + ) + + if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE: + logger.debug( + "Output of EXPLAIN.. 
ANALYZE from get_item_etag():\n%s",
+            Item.objects.filter(
+                collection__name=kwargs['collection_name'], name=kwargs['item_name']
+            ).explain(verbose=True, analyze=True)
+        )
+        logger.debug(
+            "The corresponding SQL statement:\n%s",
+            Item.objects.filter(
+                collection__name=kwargs['collection_name'], name=kwargs['item_name']
+            ).query
+        )
+
+    return tag
+
+
+class ItemsList(generics.GenericAPIView):
+    serializer_class = ItemSerializer
+    ordering = ['name']
+    name = 'items-list'  # this name must match the name in urls.py
+
+    def get_queryset(self):
+        # filter based on the url
+        queryset = Item.objects.filter(
+            # filter expired items
+            Q(properties_expires__gte=timezone.now()) | Q(properties_expires=None),
+            collection__name=self.kwargs['collection_name']
+        ).prefetch_related(Prefetch('assets', queryset=Asset.objects.order_by('name')), 'links')
+        bbox = self.request.query_params.get('bbox', None)
+        date_time = self.request.query_params.get('datetime', None)
+
+        if bbox:
+            queryset = queryset.filter_by_bbox(bbox)
+
+        if date_time:
+            queryset = queryset.filter_by_datetime(date_time)
+
+        if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE:
+            logger.debug(
+                "Output of EXPLAIN.. ANALYZE from ItemList() view:\n%s",
+                queryset.explain(verbose=True, analyze=True)
+            )
+            logger.debug("The corresponding SQL statement:\n%s", queryset.query)
+
+        return queryset
+
+    def list(self, request, *args, **kwargs):
+        validate_collection(self.kwargs)
+        queryset = self.filter_queryset(self.get_queryset())
+        update_interval = Collection.objects.values('update_interval').get(
+            name=self.kwargs['collection_name']
+        )['update_interval']
+        page = self.paginate_queryset(queryset)
+        if page is not None:
+            serializer = self.get_serializer(page, many=True)
+        else:
+            serializer = self.get_serializer(queryset, many=True)
+
+        data = {
+            'type': 'FeatureCollection',
+            'timeStamp': utc_aware(datetime.utcnow()),
+            'features': serializer.data,
+            'links': get_relation_links(request, self.name, [self.kwargs['collection_name']])
+        }
+
+        if page is not None:
+            response = self.get_paginated_response(data)
+        else:
+            response = Response(data)
+        views_mixins.patch_cache_settings_by_update_interval(response, update_interval)
+        return response
+
+    def get(self, request, *args, **kwargs):
+        return self.list(request, *args, **kwargs)
+
+
+class ItemDetail(
+    generics.GenericAPIView,
+    views_mixins.RetrieveModelDynCacheMixin,
+    views_mixins.UpdateInsertModelMixin,
+    views_mixins.DestroyModelMixin
+):
+    # this name must match the name in urls.py and is used by the DestroyModelMixin
+    name = 'item-detail'
+    serializer_class = ItemSerializer
+    lookup_url_kwarg = "item_name"
+    lookup_field = "name"
+
+    def get_queryset(self):
+        # filter based on the url
+        queryset = Item.objects.filter(
+            # filter expired items
+            Q(properties_expires__gte=timezone.now()) | Q(properties_expires=None),
+            collection__name=self.kwargs['collection_name']
+        ).prefetch_related(Prefetch('assets', queryset=Asset.objects.order_by('name')), 'links')
+
+        if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE:
+            logger.debug(
+                "Output of EXPLAIN..
ANALYZE from ItemDetail() view:\n%s", + queryset.explain(verbose=True, analyze=True) + ) + logger.debug("The corresponding SQL statement:\n%s", queryset.query) + + return queryset + + def perform_update(self, serializer): + collection = get_object_or_404(Collection, name=self.kwargs['collection_name']) + validate_renaming( + serializer, + self.kwargs['item_name'], + extra_log={ + 'request': self.request._request, # pylint: disable=protected-access + 'collection': self.kwargs['collection_name'], + 'item': self.kwargs['item_name'] + } + ) + serializer.save(collection=collection) + + def perform_upsert(self, serializer, lookup): + collection = get_object_or_404(Collection, name=self.kwargs['collection_name']) + validate_renaming( + serializer, + self.kwargs['item_name'], + extra_log={ + 'request': self.request._request, # pylint: disable=protected-access + 'collection': self.kwargs['collection_name'], + 'item': self.kwargs['item_name'] + } + ) + lookup['collection__name'] = collection.name + return serializer.upsert(lookup, collection=collection) + + @etag(get_item_etag) + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_item_etag) + def put(self, request, *args, **kwargs): + return self.upsert(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_item_etag) + def patch(self, request, *args, **kwargs): + return self.partial_update(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_item_etag) + def delete(self, request, *args, **kwargs): + return self.destroy(request, *args, **kwargs) diff --git a/app/stac_api/views/views.py b/app/stac_api/views/views.py index 43b08719..391427fe 100644 --- a/app/stac_api/views/views.py +++ b/app/stac_api/views/views.py @@ -7,7 +7,6 @@ from django.db import IntegrityError from django.db import transaction from django.db.models import Min -from django.db.models import Prefetch from django.db.models import Q from django.utils import timezone from django.utils.translation import gettext_lazy as _ @@ -70,32 +69,6 @@ def get_etag(queryset): return None -def get_item_etag(request, *args, **kwargs): - '''Get the ETag for a item object - - The ETag is an UUID4 computed on each object changes (including relations; assets and links) - ''' - tag = get_etag( - Item.objects.filter(collection__name=kwargs['collection_name'], name=kwargs['item_name']) - ) - - if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE: - logger.debug( - "Output of EXPLAIN.. 
ANALYZE from get_item_etag():\n%s", - Item.objects.filter( - collection__name=kwargs['collection_name'], name=kwargs['item_name'] - ).explain(verbose=True, analyze=True) - ) - logger.debug( - "The corresponding SQL statement:\n%s", - Item.objects.filter( - collection__name=kwargs['collection_name'], name=kwargs['item_name'] - ).query - ) - - return tag - - def get_collection_asset_etag(request, *args, **kwargs): '''Get the ETag for a collection asset object @@ -301,141 +274,6 @@ def recalculate_extent(request): return Response() -class ItemsList(generics.GenericAPIView): - serializer_class = ItemSerializer - ordering = ['name'] - name = 'items-list' # this name must match the name in urls.py - - def get_queryset(self): - # filter based on the url - queryset = Item.objects.filter( - # filter expired items - Q(properties_expires__gte=timezone.now()) | Q(properties_expires=None), - collection__name=self.kwargs['collection_name'] - ).prefetch_related(Prefetch('assets', queryset=Asset.objects.order_by('name')), 'links') - bbox = self.request.query_params.get('bbox', None) - date_time = self.request.query_params.get('datetime', None) - - if bbox: - queryset = queryset.filter_by_bbox(bbox) - - if date_time: - queryset = queryset.filter_by_datetime(date_time) - - if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE: - logger.debug( - "Output of EXPLAIN.. ANALYZE from ItemList() view:\n%s", - queryset.explain(verbose=True, analyze=True) - ) - logger.debug("The corresponding SQL statement:\n%s", queryset.query) - - return queryset - - def list(self, request, *args, **kwargs): - validate_collection(self.kwargs) - queryset = self.filter_queryset(self.get_queryset()) - update_interval = Collection.objects.values('update_interval').get( - name=self.kwargs['collection_name'] - )['update_interval'] - page = self.paginate_queryset(queryset) - if page is not None: - serializer = self.get_serializer(page, many=True) - else: - serializer = self.get_serializer(queryset, many=True) - - data = { - 'type': 'FeatureCollection', - 'timeStamp': utc_aware(datetime.utcnow()), - 'features': serializer.data, - 'links': get_relation_links(request, self.name, [self.kwargs['collection_name']]) - } - - if page is not None: - response = self.get_paginated_response(data) - response = Response(data) - views_mixins.patch_cache_settings_by_update_interval(response, update_interval) - return response - - def get(self, request, *args, **kwargs): - return self.list(request, *args, **kwargs) - - -class ItemDetail( - generics.GenericAPIView, - views_mixins.RetrieveModelDynCacheMixin, - views_mixins.UpdateInsertModelMixin, - views_mixins.DestroyModelMixin -): - # this name must match the name in urls.py and is used by the DestroyModelMixin - name = 'item-detail' - serializer_class = ItemSerializer - lookup_url_kwarg = "item_name" - lookup_field = "name" - - def get_queryset(self): - # filter based on the url - queryset = Item.objects.filter( - # filter expired items - Q(properties_expires__gte=timezone.now()) | Q(properties_expires=None), - collection__name=self.kwargs['collection_name'] - ).prefetch_related(Prefetch('assets', queryset=Asset.objects.order_by('name')), 'links') - - if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE: - logger.debug( - "Output of EXPLAIN.. 
ANALYZE from ItemDetail() view:\n%s", - queryset.explain(verbose=True, analyze=True) - ) - logger.debug("The corresponding SQL statement:\n%s", queryset.query) - - return queryset - - def perform_update(self, serializer): - collection = get_object_or_404(Collection, name=self.kwargs['collection_name']) - validate_renaming( - serializer, - self.kwargs['item_name'], - extra_log={ - 'request': self.request._request, # pylint: disable=protected-access - 'collection': self.kwargs['collection_name'], - 'item': self.kwargs['item_name'] - } - ) - serializer.save(collection=collection) - - def perform_upsert(self, serializer, lookup): - collection = get_object_or_404(Collection, name=self.kwargs['collection_name']) - validate_renaming( - serializer, - self.kwargs['item_name'], - extra_log={ - 'request': self.request._request, # pylint: disable=protected-access - 'collection': self.kwargs['collection_name'], - 'item': self.kwargs['item_name'] - } - ) - lookup['collection__name'] = collection.name - return serializer.upsert(lookup, collection=collection) - - @etag(get_item_etag) - def get(self, request, *args, **kwargs): - return self.retrieve(request, *args, **kwargs) - - # Here the etag is only added to support pre-conditional If-Match and If-Not-Match - @etag(get_item_etag) - def put(self, request, *args, **kwargs): - return self.upsert(request, *args, **kwargs) - - # Here the etag is only added to support pre-conditional If-Match and If-Not-Match - @etag(get_item_etag) - def patch(self, request, *args, **kwargs): - return self.partial_update(request, *args, **kwargs) - - # Here the etag is only added to support pre-conditional If-Match and If-Not-Match - @etag(get_item_etag) - def delete(self, request, *args, **kwargs): - return self.destroy(request, *args, **kwargs) - - class AssetsList(generics.GenericAPIView): name = 'assets-list' # this name must match the name in urls.py serializer_class = AssetSerializer diff --git a/app/stac_api/views/views_test.py b/app/stac_api/views/views_test.py index 7d8c89a3..78f7d4ae 100644 --- a/app/stac_api/views/views_test.py +++ b/app/stac_api/views/views_test.py @@ -4,9 +4,9 @@ from stac_api.models import LandingPage from stac_api.views.collection import CollectionDetail +from stac_api.views.item import ItemDetail from stac_api.views.views import AssetDetail from stac_api.views.views import CollectionAssetDetail -from stac_api.views.views import ItemDetail logger = logging.getLogger(__name__) From 0cece9e5da18ae47efcc7568dd2366bda5ef13f6 Mon Sep 17 00:00:00 2001 From: Benjamin Sugden Date: Wed, 11 Sep 2024 07:31:47 +0200 Subject: [PATCH 08/16] Move asset views Move asset views to item file, move collection asset views to collection file. 
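For orientation, the resulting layout of the views package after this commit
(a sketch; views.py still hosts the landing page, conformance, search and
upload views at this point, and both new modules import the shared get_etag()
helper from it):

    app/stac_api/views/
        __init__.py
        collection.py    # collection views plus the collection asset views moved here
        item.py          # item views plus the item asset views moved here
        views.py         # landing page, conformance, search and upload views (for now)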
--- app/stac_api/urls.py | 8 +- app/stac_api/views/collection.py | 142 ++++++++++++++ app/stac_api/views/item.py | 177 +++++++++++++++++ app/stac_api/views/views.py | 319 ------------------------------- app/stac_api/views/views_test.py | 4 +- 5 files changed, 325 insertions(+), 325 deletions(-) diff --git a/app/stac_api/urls.py b/app/stac_api/urls.py index 24e9b0a0..95b2ba46 100644 --- a/app/stac_api/urls.py +++ b/app/stac_api/urls.py @@ -4,19 +4,19 @@ from rest_framework.authtoken.views import obtain_auth_token +from stac_api.views.collection import CollectionAssetDetail +from stac_api.views.collection import CollectionAssetsList from stac_api.views.collection import CollectionDetail from stac_api.views.collection import CollectionList +from stac_api.views.item import AssetDetail +from stac_api.views.item import AssetsList from stac_api.views.item import ItemDetail from stac_api.views.item import ItemsList -from stac_api.views.views import AssetDetail -from stac_api.views.views import AssetsList from stac_api.views.views import AssetUploadAbort from stac_api.views.views import AssetUploadComplete from stac_api.views.views import AssetUploadDetail from stac_api.views.views import AssetUploadPartsList from stac_api.views.views import AssetUploadsList -from stac_api.views.views import CollectionAssetDetail -from stac_api.views.views import CollectionAssetsList from stac_api.views.views import CollectionAssetUploadAbort from stac_api.views.views import CollectionAssetUploadComplete from stac_api.views.views import CollectionAssetUploadDetail diff --git a/app/stac_api/views/collection.py b/app/stac_api/views/collection.py index 366e9c24..185a3d6b 100644 --- a/app/stac_api/views/collection.py +++ b/app/stac_api/views/collection.py @@ -4,12 +4,17 @@ from rest_framework import generics from rest_framework import mixins +from rest_framework.generics import get_object_or_404 from rest_framework.response import Response from rest_framework_condition import etag from stac_api.models import Collection +from stac_api.models import CollectionAsset +from stac_api.serializers import CollectionAssetSerializer from stac_api.serializers import CollectionSerializer from stac_api.serializers_utils import get_relation_links +from stac_api.utils import get_collection_asset_path +from stac_api.validators_view import validate_collection from stac_api.validators_view import validate_renaming from stac_api.views import views_mixins from stac_api.views.views import get_etag @@ -38,6 +43,31 @@ def get_collection_etag(request, *args, **kwargs): return tag +def get_collection_asset_etag(request, *args, **kwargs): + '''Get the ETag for a collection asset object + + The ETag is an UUID4 computed on each object changes + ''' + tag = get_etag( + CollectionAsset.objects.filter( + collection__name=kwargs['collection_name'], name=kwargs['asset_name'] + ) + ) + + if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE: + logger.debug( + "Output of EXPLAIN.. 
ANALYZE from get_collection_asset_etag():\n%s", + CollectionAsset.objects.filter(name=kwargs['asset_name'] + ).explain(verbose=True, analyze=True) + ) + logger.debug( + "The corresponding SQL statement:\n%s", + CollectionAsset.objects.filter(name=kwargs['asset_name']).query + ) + + return tag + + class CollectionList(generics.GenericAPIView): name = 'collections-list' # this name must match the name in urls.py serializer_class = CollectionSerializer @@ -117,3 +147,115 @@ def perform_update(self, serializer, *args, **kwargs): } ) return super().perform_update(serializer, *args, **kwargs) + + +class CollectionAssetsList(generics.GenericAPIView): + name = 'collection-assets-list' # this name must match the name in urls.py + serializer_class = CollectionAssetSerializer + pagination_class = None + + def get_queryset(self): + # filter based on the url + return CollectionAsset.objects.filter(collection__name=self.kwargs['collection_name'] + ).order_by('name') + + def get(self, request, *args, **kwargs): + validate_collection(self.kwargs) + + queryset = self.filter_queryset(self.get_queryset()) + update_interval = Collection.objects.values('update_interval').get( + name=self.kwargs['collection_name'] + )['update_interval'] + serializer = self.get_serializer(queryset, many=True) + + data = { + 'assets': serializer.data, + 'links': get_relation_links(request, self.name, [self.kwargs['collection_name']]) + } + response = Response(data) + views_mixins.patch_cache_settings_by_update_interval(response, update_interval) + return response + + +class CollectionAssetDetail( + generics.GenericAPIView, + views_mixins.UpdateInsertModelMixin, + views_mixins.DestroyModelMixin, + views_mixins.RetrieveModelDynCacheMixin +): + # this name must match the name in urls.py and is used by the DestroyModelMixin + name = 'collection-asset-detail' + serializer_class = CollectionAssetSerializer + lookup_url_kwarg = "asset_name" + lookup_field = "name" + + def get_queryset(self): + # filter based on the url + return CollectionAsset.objects.filter(collection__name=self.kwargs['collection_name']) + + def get_serializer(self, *args, **kwargs): + serializer_class = self.get_serializer_class() + kwargs.setdefault('context', self.get_serializer_context()) + collection = get_object_or_404(Collection, name=self.kwargs['collection_name']) + serializer = serializer_class(*args, **kwargs) + + # for the validation the serializer needs to know the collection of the + # asset. In case of inserting, the asset doesn't exist and thus the collection + # can't be read from the instance, which is why we pass the collection manually + # here. 
See serializers.AssetBaseSerializer._validate_href_field + serializer.collection = collection + return serializer + + def perform_update(self, serializer): + collection = get_object_or_404(Collection, name=self.kwargs['collection_name']) + validate_renaming( + serializer, + original_id=self.kwargs['asset_name'], + extra_log={ + 'request': self.request._request, # pylint: disable=protected-access + 'collection': self.kwargs['collection_name'], + 'asset': self.kwargs['asset_name'] + } + ) + return serializer.save( + collection=collection, + file=get_collection_asset_path(collection, self.kwargs['asset_name']) + ) + + def perform_upsert(self, serializer, lookup): + collection = get_object_or_404(Collection, name=self.kwargs['collection_name']) + validate_renaming( + serializer, + original_id=self.kwargs['asset_name'], + extra_log={ + 'request': self.request._request, # pylint: disable=protected-access + 'collection': self.kwargs['collection_name'], + 'asset': self.kwargs['asset_name'] + } + ) + lookup['collection__name'] = collection.name + + return serializer.upsert( + lookup, + collection=collection, + file=get_collection_asset_path(collection, self.kwargs['asset_name']) + ) + + @etag(get_collection_asset_etag) + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_collection_asset_etag) + def put(self, request, *args, **kwargs): + return self.upsert(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_collection_asset_etag) + def patch(self, request, *args, **kwargs): + return self.partial_update(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_collection_asset_etag) + def delete(self, request, *args, **kwargs): + return self.destroy(request, *args, **kwargs) diff --git a/app/stac_api/views/item.py b/app/stac_api/views/item.py index a9113ab6..91ca23b7 100644 --- a/app/stac_api/views/item.py +++ b/app/stac_api/views/item.py @@ -14,10 +14,13 @@ from stac_api.models import Asset from stac_api.models import Collection from stac_api.models import Item +from stac_api.serializers import AssetSerializer from stac_api.serializers import ItemSerializer from stac_api.serializers_utils import get_relation_links +from stac_api.utils import get_asset_path from stac_api.utils import utc_aware from stac_api.validators_view import validate_collection +from stac_api.validators_view import validate_item from stac_api.validators_view import validate_renaming from stac_api.views import views_mixins from stac_api.views.views import get_etag @@ -51,6 +54,33 @@ def get_item_etag(request, *args, **kwargs): return tag +def get_asset_etag(request, *args, **kwargs): + '''Get the ETag for a asset object + + The ETag is an UUID4 computed on each object changes + ''' + tag = get_etag( + Asset.objects.filter( + item__collection__name=kwargs['collection_name'], + item__name=kwargs['item_name'], + name=kwargs['asset_name'] + ) + ) + + if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE: + logger.debug( + "Output of EXPLAIN.. 
ANALYZE from get_asset_etag():\n%s", + Asset.objects.filter(item__name=kwargs['item_name'], + name=kwargs['asset_name']).explain(verbose=True, analyze=True) + ) + logger.debug( + "The corresponding SQL statement:\n%s", + Asset.objects.filter(item__name=kwargs['item_name'], name=kwargs['asset_name']).query + ) + + return tag + + class ItemsList(generics.GenericAPIView): serializer_class = ItemSerializer ordering = ['name'] @@ -184,3 +214,150 @@ def patch(self, request, *args, **kwargs): @etag(get_item_etag) def delete(self, request, *args, **kwargs): return self.destroy(request, *args, **kwargs) + + +class AssetsList(generics.GenericAPIView): + name = 'assets-list' # this name must match the name in urls.py + serializer_class = AssetSerializer + pagination_class = None + + def get_queryset(self): + # filter based on the url + return Asset.objects.filter( + item__collection__name=self.kwargs['collection_name'], + item__name=self.kwargs['item_name'] + ).order_by('name') + + def get(self, request, *args, **kwargs): + validate_item(self.kwargs) + + queryset = self.filter_queryset(self.get_queryset()) + update_interval = Item.objects.values('update_interval').get( + collection__name=self.kwargs['collection_name'], + name=self.kwargs['item_name'], + )['update_interval'] + serializer = self.get_serializer(queryset, many=True) + + data = { + 'assets': serializer.data, + 'links': + get_relation_links( + request, self.name, [self.kwargs['collection_name'], self.kwargs['item_name']] + ) + } + response = Response(data) + views_mixins.patch_cache_settings_by_update_interval(response, update_interval) + return response + + +class AssetDetail( + generics.GenericAPIView, + views_mixins.UpdateInsertModelMixin, + views_mixins.DestroyModelMixin, + views_mixins.RetrieveModelDynCacheMixin +): + # this name must match the name in urls.py and is used by the DestroyModelMixin + name = 'asset-detail' + serializer_class = AssetSerializer + lookup_url_kwarg = "asset_name" + lookup_field = "name" + + def get_queryset(self): + # filter based on the url + return Asset.objects.filter( + Q(item__properties_expires=None) | Q(item__properties_expires__gte=timezone.now()), + item__collection__name=self.kwargs['collection_name'], + item__name=self.kwargs['item_name'] + ) + + def get_serializer(self, *args, **kwargs): + serializer_class = self.get_serializer_class() + kwargs.setdefault('context', self.get_serializer_context()) + item = get_object_or_404( + Item, collection__name=self.kwargs['collection_name'], name=self.kwargs['item_name'] + ) + serializer = serializer_class(*args, **kwargs) + + # for the validation the serializer needs to know the collection of the + # item. In case of upserting, the asset doesn't exist and thus the collection + # can't be read from the instance, which is why we pass the collection manually + # here. See serialiers.AssetBaseSerializer._validate_href_field + serializer.collection = item.collection + return serializer + + def _get_file_path(self, serializer, item, asset_name): + """Get the path to the file + + If the collection allows for external asset, and the file is specified + in the request, we set it directly. If the collection doesn't allow it, + error 400. 
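Note: the @etag(get_asset_etag) decorators declared further down exist purely for HTTP preconditions ("If-Not-Match" in the comments refers to the standard If-None-Match request header). A hedged client-side illustration of the If-Match flow they enable; the host, collection, item, and asset names below are placeholders:

    import requests

    base = 'https://example.org/api/stac/v1'  # placeholder host
    url = f'{base}/collections/col1/items/item1/assets/asset1.tiff'

    etag = requests.get(url, timeout=10).headers['ETag']
    # The write is rejected with 412 Precondition Failed if the asset changed
    # (and therefore got a new ETag) after the GET above; auth is omitted.
    resp = requests.patch(
        url, json={'title': 'new title'}, headers={'If-Match': etag}, timeout=10
    )
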
+ Otherwise we assemble the path from the file name, collection name as + well as the s3 bucket domain + """ + + if 'file' in serializer.validated_data: + file = serializer.validated_data['file'] + # setting the href makes the asset be external implicitly + is_external = True + else: + file = get_asset_path(item, asset_name) + is_external = False + + return file, is_external + + def perform_update(self, serializer): + item = get_object_or_404( + Item, collection__name=self.kwargs['collection_name'], name=self.kwargs['item_name'] + ) + validate_renaming( + serializer, + original_id=self.kwargs['asset_name'], + extra_log={ + 'request': self.request._request, # pylint: disable=protected-access + 'collection': self.kwargs['collection_name'], + 'item': self.kwargs['item_name'], + 'asset': self.kwargs['asset_name'] + } + ) + file, is_external = self._get_file_path(serializer, item, self.kwargs['asset_name']) + return serializer.save(item=item, file=file, is_external=is_external) + + def perform_upsert(self, serializer, lookup): + item = get_object_or_404( + Item, collection__name=self.kwargs['collection_name'], name=self.kwargs['item_name'] + ) + + validate_renaming( + serializer, + original_id=self.kwargs['asset_name'], + extra_log={ + 'request': self.request._request, # pylint: disable=protected-access + 'collection': self.kwargs['collection_name'], + 'item': self.kwargs['item_name'], + 'asset': self.kwargs['asset_name'] + } + ) + lookup['item__collection__name'] = item.collection.name + lookup['item__name'] = item.name + + file, is_external = self._get_file_path(serializer, item, self.kwargs['asset_name']) + return serializer.upsert(lookup, item=item, file=file, is_external=is_external) + + @etag(get_asset_etag) + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_asset_etag) + def put(self, request, *args, **kwargs): + return self.upsert(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_asset_etag) + def patch(self, request, *args, **kwargs): + return self.partial_update(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_asset_etag) + def delete(self, request, *args, **kwargs): + return self.destroy(request, *args, **kwargs) diff --git a/app/stac_api/views/views.py b/app/stac_api/views/views.py index 391427fe..79a1509e 100644 --- a/app/stac_api/views/views.py +++ b/app/stac_api/views/views.py @@ -7,8 +7,6 @@ from django.db import IntegrityError from django.db import transaction from django.db.models import Min -from django.db.models import Q -from django.utils import timezone from django.utils.translation import gettext_lazy as _ from rest_framework import generics @@ -28,7 +26,6 @@ from stac_api.models import Asset from stac_api.models import AssetUpload from stac_api.models import BaseAssetUpload -from stac_api.models import Collection from stac_api.models import CollectionAsset from stac_api.models import CollectionAssetUpload from stac_api.models import Item @@ -36,10 +33,8 @@ from stac_api.pagination import ExtApiPagination from stac_api.pagination import GetPostCursorPagination from stac_api.s3_multipart_upload import MultipartUpload -from stac_api.serializers import AssetSerializer from stac_api.serializers import AssetUploadPartsSerializer from stac_api.serializers import AssetUploadSerializer -from 
stac_api.serializers import CollectionAssetSerializer from stac_api.serializers import CollectionAssetUploadSerializer from stac_api.serializers import ConformancePageSerializer from stac_api.serializers import ItemSerializer @@ -54,10 +49,7 @@ from stac_api.utils import utc_aware from stac_api.validators_serializer import ValidateSearchRequest from stac_api.validators_view import validate_asset -from stac_api.validators_view import validate_collection from stac_api.validators_view import validate_collection_asset -from stac_api.validators_view import validate_item -from stac_api.validators_view import validate_renaming from stac_api.views import views_mixins logger = logging.getLogger(__name__) @@ -69,58 +61,6 @@ def get_etag(queryset): return None -def get_collection_asset_etag(request, *args, **kwargs): - '''Get the ETag for a collection asset object - - The ETag is an UUID4 computed on each object changes - ''' - tag = get_etag( - CollectionAsset.objects.filter( - collection__name=kwargs['collection_name'], name=kwargs['asset_name'] - ) - ) - - if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE: - logger.debug( - "Output of EXPLAIN.. ANALYZE from get_collection_asset_etag():\n%s", - CollectionAsset.objects.filter(name=kwargs['asset_name'] - ).explain(verbose=True, analyze=True) - ) - logger.debug( - "The corresponding SQL statement:\n%s", - CollectionAsset.objects.filter(name=kwargs['asset_name']).query - ) - - return tag - - -def get_asset_etag(request, *args, **kwargs): - '''Get the ETag for a asset object - - The ETag is an UUID4 computed on each object changes - ''' - tag = get_etag( - Asset.objects.filter( - item__collection__name=kwargs['collection_name'], - item__name=kwargs['item_name'], - name=kwargs['asset_name'] - ) - ) - - if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE: - logger.debug( - "Output of EXPLAIN.. 
ANALYZE from get_asset_etag():\n%s", - Asset.objects.filter(item__name=kwargs['item_name'], - name=kwargs['asset_name']).explain(verbose=True, analyze=True) - ) - logger.debug( - "The corresponding SQL statement:\n%s", - Asset.objects.filter(item__name=kwargs['item_name'], name=kwargs['asset_name']).query - ) - - return tag - - def get_asset_upload_etag(request, *args, **kwargs): '''Get the ETag for an asset upload object @@ -274,265 +214,6 @@ def recalculate_extent(request): return Response() -class AssetsList(generics.GenericAPIView): - name = 'assets-list' # this name must match the name in urls.py - serializer_class = AssetSerializer - pagination_class = None - - def get_queryset(self): - # filter based on the url - return Asset.objects.filter( - item__collection__name=self.kwargs['collection_name'], - item__name=self.kwargs['item_name'] - ).order_by('name') - - def get(self, request, *args, **kwargs): - validate_item(self.kwargs) - - queryset = self.filter_queryset(self.get_queryset()) - update_interval = Item.objects.values('update_interval').get( - collection__name=self.kwargs['collection_name'], - name=self.kwargs['item_name'], - )['update_interval'] - serializer = self.get_serializer(queryset, many=True) - - data = { - 'assets': serializer.data, - 'links': - get_relation_links( - request, self.name, [self.kwargs['collection_name'], self.kwargs['item_name']] - ) - } - response = Response(data) - views_mixins.patch_cache_settings_by_update_interval(response, update_interval) - return response - - -class CollectionAssetsList(generics.GenericAPIView): - name = 'collection-assets-list' # this name must match the name in urls.py - serializer_class = CollectionAssetSerializer - pagination_class = None - - def get_queryset(self): - # filter based on the url - return CollectionAsset.objects.filter(collection__name=self.kwargs['collection_name'] - ).order_by('name') - - def get(self, request, *args, **kwargs): - validate_collection(self.kwargs) - - queryset = self.filter_queryset(self.get_queryset()) - update_interval = Collection.objects.values('update_interval').get( - name=self.kwargs['collection_name'] - )['update_interval'] - serializer = self.get_serializer(queryset, many=True) - - data = { - 'assets': serializer.data, - 'links': get_relation_links(request, self.name, [self.kwargs['collection_name']]) - } - response = Response(data) - views_mixins.patch_cache_settings_by_update_interval(response, update_interval) - return response - - -class AssetDetail( - generics.GenericAPIView, - views_mixins.UpdateInsertModelMixin, - views_mixins.DestroyModelMixin, - views_mixins.RetrieveModelDynCacheMixin -): - # this name must match the name in urls.py and is used by the DestroyModelMixin - name = 'asset-detail' - serializer_class = AssetSerializer - lookup_url_kwarg = "asset_name" - lookup_field = "name" - - def get_queryset(self): - # filter based on the url - return Asset.objects.filter( - Q(item__properties_expires=None) | Q(item__properties_expires__gte=timezone.now()), - item__collection__name=self.kwargs['collection_name'], - item__name=self.kwargs['item_name'] - ) - - def get_serializer(self, *args, **kwargs): - serializer_class = self.get_serializer_class() - kwargs.setdefault('context', self.get_serializer_context()) - item = get_object_or_404( - Item, collection__name=self.kwargs['collection_name'], name=self.kwargs['item_name'] - ) - serializer = serializer_class(*args, **kwargs) - - # for the validation the serializer needs to know the collection of the - # item. 
In case of upserting, the asset doesn't exist and thus the collection - # can't be read from the instance, which is why we pass the collection manually - # here. See serialiers.AssetBaseSerializer._validate_href_field - serializer.collection = item.collection - return serializer - - def _get_file_path(self, serializer, item, asset_name): - """Get the path to the file - - If the collection allows for external asset, and the file is specified - in the request, we set it directly. If the collection doesn't allow it, - error 400. - Otherwise we assemble the path from the file name, collection name as - well as the s3 bucket domain - """ - - if 'file' in serializer.validated_data: - file = serializer.validated_data['file'] - # setting the href makes the asset be external implicitly - is_external = True - else: - file = get_asset_path(item, asset_name) - is_external = False - - return file, is_external - - def perform_update(self, serializer): - item = get_object_or_404( - Item, collection__name=self.kwargs['collection_name'], name=self.kwargs['item_name'] - ) - validate_renaming( - serializer, - original_id=self.kwargs['asset_name'], - extra_log={ - 'request': self.request._request, # pylint: disable=protected-access - 'collection': self.kwargs['collection_name'], - 'item': self.kwargs['item_name'], - 'asset': self.kwargs['asset_name'] - } - ) - file, is_external = self._get_file_path(serializer, item, self.kwargs['asset_name']) - return serializer.save(item=item, file=file, is_external=is_external) - - def perform_upsert(self, serializer, lookup): - item = get_object_or_404( - Item, collection__name=self.kwargs['collection_name'], name=self.kwargs['item_name'] - ) - - validate_renaming( - serializer, - original_id=self.kwargs['asset_name'], - extra_log={ - 'request': self.request._request, # pylint: disable=protected-access - 'collection': self.kwargs['collection_name'], - 'item': self.kwargs['item_name'], - 'asset': self.kwargs['asset_name'] - } - ) - lookup['item__collection__name'] = item.collection.name - lookup['item__name'] = item.name - - file, is_external = self._get_file_path(serializer, item, self.kwargs['asset_name']) - return serializer.upsert(lookup, item=item, file=file, is_external=is_external) - - @etag(get_asset_etag) - def get(self, request, *args, **kwargs): - return self.retrieve(request, *args, **kwargs) - - # Here the etag is only added to support pre-conditional If-Match and If-Not-Match - @etag(get_asset_etag) - def put(self, request, *args, **kwargs): - return self.upsert(request, *args, **kwargs) - - # Here the etag is only added to support pre-conditional If-Match and If-Not-Match - @etag(get_asset_etag) - def patch(self, request, *args, **kwargs): - return self.partial_update(request, *args, **kwargs) - - # Here the etag is only added to support pre-conditional If-Match and If-Not-Match - @etag(get_asset_etag) - def delete(self, request, *args, **kwargs): - return self.destroy(request, *args, **kwargs) - - -class CollectionAssetDetail( - generics.GenericAPIView, - views_mixins.UpdateInsertModelMixin, - views_mixins.DestroyModelMixin, - views_mixins.RetrieveModelDynCacheMixin -): - # this name must match the name in urls.py and is used by the DestroyModelMixin - name = 'collection-asset-detail' - serializer_class = CollectionAssetSerializer - lookup_url_kwarg = "asset_name" - lookup_field = "name" - - def get_queryset(self): - # filter based on the url - return CollectionAsset.objects.filter(collection__name=self.kwargs['collection_name']) - - def 
get_serializer(self, *args, **kwargs): - serializer_class = self.get_serializer_class() - kwargs.setdefault('context', self.get_serializer_context()) - collection = get_object_or_404(Collection, name=self.kwargs['collection_name']) - serializer = serializer_class(*args, **kwargs) - - # for the validation the serializer needs to know the collection of the - # asset. In case of inserting, the asset doesn't exist and thus the collection - # can't be read from the instance, which is why we pass the collection manually - # here. See serializers.AssetBaseSerializer._validate_href_field - serializer.collection = collection - return serializer - - def perform_update(self, serializer): - collection = get_object_or_404(Collection, name=self.kwargs['collection_name']) - validate_renaming( - serializer, - original_id=self.kwargs['asset_name'], - extra_log={ - 'request': self.request._request, # pylint: disable=protected-access - 'collection': self.kwargs['collection_name'], - 'asset': self.kwargs['asset_name'] - } - ) - return serializer.save( - collection=collection, - file=get_collection_asset_path(collection, self.kwargs['asset_name']) - ) - - def perform_upsert(self, serializer, lookup): - collection = get_object_or_404(Collection, name=self.kwargs['collection_name']) - validate_renaming( - serializer, - original_id=self.kwargs['asset_name'], - extra_log={ - 'request': self.request._request, # pylint: disable=protected-access - 'collection': self.kwargs['collection_name'], - 'asset': self.kwargs['asset_name'] - } - ) - lookup['collection__name'] = collection.name - - return serializer.upsert( - lookup, - collection=collection, - file=get_collection_asset_path(collection, self.kwargs['asset_name']) - ) - - @etag(get_collection_asset_etag) - def get(self, request, *args, **kwargs): - return self.retrieve(request, *args, **kwargs) - - # Here the etag is only added to support pre-conditional If-Match and If-Not-Match - @etag(get_collection_asset_etag) - def put(self, request, *args, **kwargs): - return self.upsert(request, *args, **kwargs) - - # Here the etag is only added to support pre-conditional If-Match and If-Not-Match - @etag(get_collection_asset_etag) - def patch(self, request, *args, **kwargs): - return self.partial_update(request, *args, **kwargs) - - # Here the etag is only added to support pre-conditional If-Match and If-Not-Match - @etag(get_collection_asset_etag) - def delete(self, request, *args, **kwargs): - return self.destroy(request, *args, **kwargs) - - class SharedAssetUploadBase(generics.GenericAPIView): """SharedAssetUploadBase provides a base view for asset uploads and collection asset uploads. 
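Note: after this patch the split of responsibilities between the view modules is: item and item-asset views in stac_api/views/item.py, collection and collection-asset views in stac_api/views/collection.py, and everything upload-related (including SharedAssetUploadBase below) still in stac_api/views/views.py. In import form, taken directly from the urls.py hunk above:

    from stac_api.views.collection import CollectionAssetDetail, CollectionAssetsList
    from stac_api.views.item import AssetDetail, AssetsList
    from stac_api.views.views import AssetUploadsList  # upload views stay put
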
""" diff --git a/app/stac_api/views/views_test.py b/app/stac_api/views/views_test.py index 78f7d4ae..85b7d1ab 100644 --- a/app/stac_api/views/views_test.py +++ b/app/stac_api/views/views_test.py @@ -3,10 +3,10 @@ from rest_framework import generics from stac_api.models import LandingPage +from stac_api.views.collection import CollectionAssetDetail from stac_api.views.collection import CollectionDetail +from stac_api.views.item import AssetDetail from stac_api.views.item import ItemDetail -from stac_api.views.views import AssetDetail -from stac_api.views.views import CollectionAssetDetail logger = logging.getLogger(__name__) From 849eb2b328eb97862a830b1081d5d4863e273546 Mon Sep 17 00:00:00 2001 From: Benjamin Sugden Date: Wed, 11 Sep 2024 07:39:27 +0200 Subject: [PATCH 09/16] Move serializer files to subfolder --- .../management/commands/list_asset_uploads.py | 2 +- .../management/commands/profile_item_serializer.py | 2 +- .../commands/profile_serializer_vs_no_drf.py | 2 +- app/stac_api/serializers/__init__.py | 0 app/stac_api/{ => serializers}/serializers.py | 10 +++++----- .../{ => serializers}/serializers_utils.py | 0 app/stac_api/views/collection.py | 6 +++--- app/stac_api/views/item.py | 6 +++--- app/stac_api/views/views.py | 14 +++++++------- app/stac_api/views/views_mixins.py | 2 +- app/tests/tests_09/test_serializer.py | 6 +++--- app/tests/tests_09/test_serializer_asset_upload.py | 2 +- app/tests/tests_10/test_serializer.py | 6 +++--- app/tests/tests_10/test_serializer_asset_upload.py | 2 +- scripts/fill_local_db.py | 8 ++++---- 15 files changed, 34 insertions(+), 34 deletions(-) create mode 100644 app/stac_api/serializers/__init__.py rename app/stac_api/{ => serializers}/serializers.py (99%) rename app/stac_api/{ => serializers}/serializers_utils.py (100%) diff --git a/app/stac_api/management/commands/list_asset_uploads.py b/app/stac_api/management/commands/list_asset_uploads.py index 19c90ece..ff4d1b83 100644 --- a/app/stac_api/management/commands/list_asset_uploads.py +++ b/app/stac_api/management/commands/list_asset_uploads.py @@ -6,7 +6,7 @@ from stac_api.models import AssetUpload from stac_api.s3_multipart_upload import MultipartUpload -from stac_api.serializers import AssetUploadSerializer +from stac_api.serializers.serializers import AssetUploadSerializer from stac_api.utils import CommandHandler from stac_api.utils import get_asset_path diff --git a/app/stac_api/management/commands/profile_item_serializer.py b/app/stac_api/management/commands/profile_item_serializer.py index 84ce4cf8..b348b175 100644 --- a/app/stac_api/management/commands/profile_item_serializer.py +++ b/app/stac_api/management/commands/profile_item_serializer.py @@ -20,7 +20,7 @@ class Handler(CommandHandler): def profiling(self): # pylint: disable=import-outside-toplevel,possibly-unused-variable - from stac_api.serializers import ItemSerializer + from stac_api.serializers.serializers import ItemSerializer collection_id = self.options["collection"] qs = Item.objects.filter(collection__name=collection_id ).prefetch_related('assets', 'links')[:self.options['limit']] diff --git a/app/stac_api/management/commands/profile_serializer_vs_no_drf.py b/app/stac_api/management/commands/profile_serializer_vs_no_drf.py index 64da7525..6d580a6c 100644 --- a/app/stac_api/management/commands/profile_serializer_vs_no_drf.py +++ b/app/stac_api/management/commands/profile_serializer_vs_no_drf.py @@ -21,7 +21,7 @@ def profiling(self): # pylint: disable=import-outside-toplevel,possibly-unused-variable self.print('Starting 
profiling') - from stac_api.serializers import ItemSerializer + from stac_api.serializers.serializers import ItemSerializer def serialize(qs): return { diff --git a/app/stac_api/serializers/__init__.py b/app/stac_api/serializers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/stac_api/serializers.py b/app/stac_api/serializers/serializers.py similarity index 99% rename from app/stac_api/serializers.py rename to app/stac_api/serializers/serializers.py index a01a572c..30f0611b 100644 --- a/app/stac_api/serializers.py +++ b/app/stac_api/serializers/serializers.py @@ -23,11 +23,11 @@ from stac_api.models import LandingPage from stac_api.models import LandingPageLink from stac_api.models import Provider -from stac_api.serializers_utils import DictSerializer -from stac_api.serializers_utils import NonNullModelSerializer -from stac_api.serializers_utils import UpsertModelSerializerMixin -from stac_api.serializers_utils import get_relation_links -from stac_api.serializers_utils import update_or_create_links +from stac_api.serializers.serializers_utils import DictSerializer +from stac_api.serializers.serializers_utils import NonNullModelSerializer +from stac_api.serializers.serializers_utils import UpsertModelSerializerMixin +from stac_api.serializers.serializers_utils import get_relation_links +from stac_api.serializers.serializers_utils import update_or_create_links from stac_api.utils import build_asset_href from stac_api.utils import get_browser_url from stac_api.utils import get_stac_version diff --git a/app/stac_api/serializers_utils.py b/app/stac_api/serializers/serializers_utils.py similarity index 100% rename from app/stac_api/serializers_utils.py rename to app/stac_api/serializers/serializers_utils.py diff --git a/app/stac_api/views/collection.py b/app/stac_api/views/collection.py index 185a3d6b..6cbb9b56 100644 --- a/app/stac_api/views/collection.py +++ b/app/stac_api/views/collection.py @@ -10,9 +10,9 @@ from stac_api.models import Collection from stac_api.models import CollectionAsset -from stac_api.serializers import CollectionAssetSerializer -from stac_api.serializers import CollectionSerializer -from stac_api.serializers_utils import get_relation_links +from stac_api.serializers.serializers import CollectionAssetSerializer +from stac_api.serializers.serializers import CollectionSerializer +from stac_api.serializers.serializers_utils import get_relation_links from stac_api.utils import get_collection_asset_path from stac_api.validators_view import validate_collection from stac_api.validators_view import validate_renaming diff --git a/app/stac_api/views/item.py b/app/stac_api/views/item.py index 91ca23b7..a760ec37 100644 --- a/app/stac_api/views/item.py +++ b/app/stac_api/views/item.py @@ -14,9 +14,9 @@ from stac_api.models import Asset from stac_api.models import Collection from stac_api.models import Item -from stac_api.serializers import AssetSerializer -from stac_api.serializers import ItemSerializer -from stac_api.serializers_utils import get_relation_links +from stac_api.serializers.serializers import AssetSerializer +from stac_api.serializers.serializers import ItemSerializer +from stac_api.serializers.serializers_utils import get_relation_links from stac_api.utils import get_asset_path from stac_api.utils import utc_aware from stac_api.validators_view import validate_collection diff --git a/app/stac_api/views/views.py b/app/stac_api/views/views.py index 79a1509e..d74d4195 100644 --- a/app/stac_api/views/views.py +++ b/app/stac_api/views/views.py @@ 
-33,13 +33,13 @@ from stac_api.pagination import ExtApiPagination from stac_api.pagination import GetPostCursorPagination from stac_api.s3_multipart_upload import MultipartUpload -from stac_api.serializers import AssetUploadPartsSerializer -from stac_api.serializers import AssetUploadSerializer -from stac_api.serializers import CollectionAssetUploadSerializer -from stac_api.serializers import ConformancePageSerializer -from stac_api.serializers import ItemSerializer -from stac_api.serializers import LandingPageSerializer -from stac_api.serializers_utils import get_relation_links +from stac_api.serializers.serializers import AssetUploadPartsSerializer +from stac_api.serializers.serializers import AssetUploadSerializer +from stac_api.serializers.serializers import CollectionAssetUploadSerializer +from stac_api.serializers.serializers import ConformancePageSerializer +from stac_api.serializers.serializers import ItemSerializer +from stac_api.serializers.serializers import LandingPageSerializer +from stac_api.serializers.serializers_utils import get_relation_links from stac_api.utils import call_calculate_extent from stac_api.utils import get_asset_path from stac_api.utils import get_collection_asset_path diff --git a/app/stac_api/views/views_mixins.py b/app/stac_api/views/views_mixins.py index 2c63c70f..8d75b186 100644 --- a/app/stac_api/views/views_mixins.py +++ b/app/stac_api/views/views_mixins.py @@ -11,7 +11,7 @@ from rest_framework import status from rest_framework.response import Response -from stac_api.serializers_utils import get_parent_link +from stac_api.serializers.serializers_utils import get_parent_link from stac_api.utils import get_dynamic_max_age_value from stac_api.utils import get_link diff --git a/app/tests/tests_09/test_serializer.py b/app/tests/tests_09/test_serializer.py index bd9a40c2..0ec7f4a0 100644 --- a/app/tests/tests_09/test_serializer.py +++ b/app/tests/tests_09/test_serializer.py @@ -13,9 +13,9 @@ from rest_framework.test import APIRequestFactory from stac_api.models import get_asset_path -from stac_api.serializers import AssetSerializer -from stac_api.serializers import CollectionSerializer -from stac_api.serializers import ItemSerializer +from stac_api.serializers.serializers import AssetSerializer +from stac_api.serializers.serializers import CollectionSerializer +from stac_api.serializers.serializers import ItemSerializer from stac_api.utils import get_link from stac_api.utils import isoformat from stac_api.utils import utc_aware diff --git a/app/tests/tests_09/test_serializer_asset_upload.py b/app/tests/tests_09/test_serializer_asset_upload.py index 7d0e8ce7..b1287734 100644 --- a/app/tests/tests_09/test_serializer_asset_upload.py +++ b/app/tests/tests_09/test_serializer_asset_upload.py @@ -7,7 +7,7 @@ from rest_framework.exceptions import ValidationError from stac_api.models import AssetUpload -from stac_api.serializers import AssetUploadSerializer +from stac_api.serializers.serializers import AssetUploadSerializer from stac_api.utils import get_sha256_multihash from tests.tests_09.base_test import STAC_BASE_V diff --git a/app/tests/tests_10/test_serializer.py b/app/tests/tests_10/test_serializer.py index e13096ae..aaa6aa30 100644 --- a/app/tests/tests_10/test_serializer.py +++ b/app/tests/tests_10/test_serializer.py @@ -13,9 +13,9 @@ from rest_framework.test import APIRequestFactory from stac_api.models import get_asset_path -from stac_api.serializers import AssetSerializer -from stac_api.serializers import CollectionSerializer -from stac_api.serializers 
import ItemSerializer +from stac_api.serializers.serializers import AssetSerializer +from stac_api.serializers.serializers import CollectionSerializer +from stac_api.serializers.serializers import ItemSerializer from stac_api.utils import get_link from stac_api.utils import isoformat from stac_api.utils import utc_aware diff --git a/app/tests/tests_10/test_serializer_asset_upload.py b/app/tests/tests_10/test_serializer_asset_upload.py index b389b256..9c7c30ab 100644 --- a/app/tests/tests_10/test_serializer_asset_upload.py +++ b/app/tests/tests_10/test_serializer_asset_upload.py @@ -7,7 +7,7 @@ from rest_framework.exceptions import ValidationError from stac_api.models import AssetUpload -from stac_api.serializers import AssetUploadSerializer +from stac_api.serializers.serializers import AssetUploadSerializer from stac_api.utils import get_sha256_multihash from tests.tests_10.base_test import StacBaseTestCase diff --git a/scripts/fill_local_db.py b/scripts/fill_local_db.py index ca8d6871..d219b096 100644 --- a/scripts/fill_local_db.py +++ b/scripts/fill_local_db.py @@ -20,10 +20,10 @@ from rest_framework.renderers import JSONRenderer from rest_framework.parsers import JSONParser from stac_api.models import * -from stac_api.serializers import CollectionSerializer -from stac_api.serializers import CollectionSerializer -from stac_api.serializers import LinkSerializer -from stac_api.serializers import ProviderSerializer +from stac_api.serializers.serializers import CollectionSerializer +from stac_api.serializers.serializers import CollectionSerializer +from stac_api.serializers.serializers import LinkSerializer +from stac_api.serializers.serializers import ProviderSerializer # create link instances for testing From 4e6b42843c361a5fd771bf8680a83f90b4c05e8f Mon Sep 17 00:00:00 2001 From: Benjamin Sugden Date: Wed, 11 Sep 2024 07:48:55 +0200 Subject: [PATCH 10/16] Move collection serializers to own file --- app/stac_api/serializers/collection.py | 413 ++++++++++++++++++++++++ app/stac_api/serializers/serializers.py | 391 ---------------------- app/stac_api/views/collection.py | 4 +- app/tests/tests_09/test_serializer.py | 2 +- app/tests/tests_10/test_serializer.py | 2 +- 5 files changed, 417 insertions(+), 395 deletions(-) create mode 100644 app/stac_api/serializers/collection.py diff --git a/app/stac_api/serializers/collection.py b/app/stac_api/serializers/collection.py new file mode 100644 index 00000000..1968e7fb --- /dev/null +++ b/app/stac_api/serializers/collection.py @@ -0,0 +1,413 @@ +import logging + +from django.contrib.gis.geos import GEOSGeometry + +from rest_framework import serializers + +from stac_api.models import Collection +from stac_api.models import CollectionAsset +from stac_api.models import CollectionLink +from stac_api.models import Provider +from stac_api.serializers.serializers import AssetsDictSerializer +from stac_api.serializers.serializers import HrefField +from stac_api.serializers.serializers_utils import NonNullModelSerializer +from stac_api.serializers.serializers_utils import UpsertModelSerializerMixin +from stac_api.serializers.serializers_utils import get_relation_links +from stac_api.serializers.serializers_utils import update_or_create_links +from stac_api.utils import get_stac_version +from stac_api.utils import is_api_version_1 +from stac_api.utils import isoformat +from stac_api.validators import normalize_and_validate_media_type +from stac_api.validators import validate_asset_name +from stac_api.validators import validate_asset_name_with_media_type +from 
stac_api.validators import validate_name +from stac_api.validators_serializer import validate_json_payload +from stac_api.validators_serializer import validate_uniqueness_and_create + +logger = logging.getLogger(__name__) + + +class ProviderSerializer(NonNullModelSerializer): + + class Meta: + model = Provider + fields = ['name', 'roles', 'url', 'description'] + + +class CollectionLinkSerializer(NonNullModelSerializer): + + class Meta: + model = CollectionLink + fields = ['href', 'rel', 'title', 'type'] + + # NOTE: when explicitely declaring fields, we need to add the validation as for the field + # in model ! + type = serializers.CharField( + required=False, allow_blank=True, max_length=150, source="link_type" + ) + + +class CollectionAssetBaseSerializer(NonNullModelSerializer, UpsertModelSerializerMixin): + '''Collection asset serializer base class + ''' + + class Meta: + model = CollectionAsset + fields = [ + 'id', + 'title', + 'type', + 'href', + 'description', + 'roles', + 'proj_epsg', + 'checksum_multihash', + 'created', + 'updated', + ] + validators = [] # Remove a default "unique together" constraint. + # (see: + # https://www.django-rest-framework.org/api-guide/validators/#limitations-of-validators) + + # NOTE: when explicitely declaring fields, we need to add the validation as for the field + # in model ! + id = serializers.CharField(source='name', max_length=255, validators=[validate_asset_name]) + title = serializers.CharField( + required=False, max_length=255, allow_null=True, allow_blank=False + ) + description = serializers.CharField(required=False, allow_blank=False, allow_null=True) + # Can't be a ChoiceField, as the validate method normalizes the MIME string only after it + # is read. Consistency is nevertheless guaranteed by the validate() and validate_type() methods. + type = serializers.CharField( + source='media_type', required=True, allow_null=False, allow_blank=False + ) + # Here we need to explicitely define these fields with the source, because they are renamed + # in the get_fields() method + proj_epsg = serializers.IntegerField(source='proj_epsg', allow_null=True, required=False) + # read only fields + checksum_multihash = serializers.CharField(source='checksum_multihash', read_only=True) + href = HrefField(source='file', read_only=True) + created = serializers.DateTimeField(read_only=True) + updated = serializers.DateTimeField(read_only=True) + + # helper variable to provide the collection for upsert validation + # see views.AssetDetail.perform_upsert + collection = None + + def create(self, validated_data): + asset = validate_uniqueness_and_create(CollectionAsset, validated_data) + return asset + + def update_or_create(self, look_up, validated_data): + """ + Update or create the asset object selected by kwargs and return the instance. + When no asset object matching the kwargs selection, a new asset is created. 
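Note: perform_upsert() in views/collection.py feeds this method a look-up containing the double-underscore key collection__name. That is safe with Django's update_or_create(): keys containing a lookup separator are used for the filtering step but dropped from the creation kwargs, so on a miss the foreign key must arrive through defaults. A sketch of the effective call; the asset name and title are made up:

    lookup = {'collection__name': collection.name, 'name': 'asset-1.tiff'}
    asset, created = CollectionAsset.objects.update_or_create(
        **lookup,
        # `collection` reaches validated_data via serializer.save(collection=...)
        # and supplies the FK when the row has to be created.
        defaults={'collection': collection, 'title': 'updated title'},
    )
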
+ Args: + validated_data: dict + Copy of the validated_data to use for update + kwargs: dict + Object selection arguments (NOTE: the selection arguments must match a unique + object in DB otherwise an IntegrityError will be raised) + Returns: tuple + Asset instance and True if created otherwise false + """ + asset, created = CollectionAsset.objects.update_or_create( + **look_up, + defaults=validated_data, + ) + return asset, created + + def validate_type(self, value): + ''' Validates the the field "type" + ''' + return normalize_and_validate_media_type(value) + + def validate(self, attrs): + name = attrs['name'] if not self.partial else attrs.get('name', self.instance.name) + media_type = attrs['media_type'] if not self.partial else attrs.get( + 'media_type', self.instance.media_type + ) + validate_asset_name_with_media_type(name, media_type) + + validate_json_payload(self) + + return attrs + + def get_fields(self): + fields = super().get_fields() + # This is a hack to allow fields with special characters + fields['proj:epsg'] = fields.pop('proj_epsg') + fields['file:checksum'] = fields.pop('checksum_multihash') + + # Older versions of the api still use different name + request = self.context.get('request') + if not is_api_version_1(request): + fields['checksum:multihash'] = fields.pop('file:checksum') + fields.pop('roles', None) + + return fields + + +class CollectionAssetSerializer(CollectionAssetBaseSerializer): + '''Collection Asset serializer for the collection asset views + + This serializer adds the links list attribute. + ''' + + def to_representation(self, instance): + collection = instance.collection.name + name = instance.name + request = self.context.get("request") + representation = super().to_representation(instance) + # Add auto links + # We use OrderedDict, although it is not necessary, because the default serializer/model for + # links already uses OrderedDict, this way we keep consistency between auto link and user + # link + representation['links'] = get_relation_links( + request, 'collection-asset-detail', [collection, name] + ) + return representation + + +class CollectionAssetsForCollectionSerializer(CollectionAssetBaseSerializer): + '''Collection assets serializer for nesting them inside the collection + + Assets should be nested inside their collection but using a dictionary instead of a list and + without links. + ''' + + class Meta: + model = CollectionAsset + list_serializer_class = AssetsDictSerializer + fields = [ + 'id', + 'title', + 'type', + 'href', + 'description', + 'roles', + 'proj_epsg', + 'checksum_multihash', + 'created', + 'updated' + ] + + +class CollectionSerializer(NonNullModelSerializer, UpsertModelSerializerMixin): + + class Meta: + model = Collection + fields = [ + 'published', + 'stac_version', + 'stac_extensions', + 'id', + 'title', + 'description', + 'summaries', + 'extent', + 'providers', + 'license', + 'created', + 'updated', + 'links', + 'crs', + 'itemType', + 'assets' + ] + # crs not in sample data, but in specs.. + validators = [] # Remove a default "unique together" constraint. + # (see: + # https://www.django-rest-framework.org/api-guide/validators/#limitations-of-validators) + + published = serializers.BooleanField(write_only=True, default=True) + # NOTE: when explicitely declaring fields, we need to add the validation as for the field + # in model ! 
+ id = serializers.CharField( + required=True, max_length=255, source="name", validators=[validate_name] + ) + title = serializers.CharField(required=False, allow_blank=False, default=None, max_length=255) + # Also links are required in the spec, the main links (self, root, items) are automatically + # generated hence here it is set to required=False which allows to add optional links that + # are not generated + links = CollectionLinkSerializer(required=False, many=True) + providers = ProviderSerializer(required=False, many=True) + + # read only fields + crs = serializers.SerializerMethodField(read_only=True) + created = serializers.DateTimeField(read_only=True) + updated = serializers.DateTimeField(read_only=True) + extent = serializers.SerializerMethodField(read_only=True) + summaries = serializers.SerializerMethodField(read_only=True) + stac_extensions = serializers.SerializerMethodField(read_only=True) + stac_version = serializers.SerializerMethodField(read_only=True) + itemType = serializers.ReadOnlyField(default="Feature") # pylint: disable=invalid-name + assets = CollectionAssetsForCollectionSerializer(many=True, read_only=True) + + def get_crs(self, obj): + return ["http://www.opengis.net/def/crs/OGC/1.3/CRS84"] + + def get_stac_extensions(self, obj): + return [] + + def get_stac_version(self, obj): + return get_stac_version(self.context.get('request')) + + def get_summaries(self, obj): + # Older versions of the api still use different name + request = self.context.get('request') + if not is_api_version_1(request): + return { + 'proj:epsg': obj.summaries_proj_epsg or [], + 'eo:gsd': obj.summaries_eo_gsd or [], + 'geoadmin:variant': obj.summaries_geoadmin_variant or [], + 'geoadmin:lang': obj.summaries_geoadmin_lang or [] + } + return { + 'proj:epsg': obj.summaries_proj_epsg or [], + 'gsd': obj.summaries_eo_gsd or [], + 'geoadmin:variant': obj.summaries_geoadmin_variant or [], + 'geoadmin:lang': obj.summaries_geoadmin_lang or [] + } + + def get_extent(self, obj): + start = obj.extent_start_datetime + end = obj.extent_end_datetime + if start is not None: + start = isoformat(start) + if end is not None: + end = isoformat(end) + + bbox = [0, 0, 0, 0] + if obj.extent_geometry is not None: + bbox = list(GEOSGeometry(obj.extent_geometry).extent) + + return { + "spatial": { + "bbox": [bbox] + }, + "temporal": { + "interval": [[start, end]] + }, + } + + def _update_or_create_providers(self, collection, providers_data): + provider_ids = [] + for provider_data in providers_data: + provider, created = Provider.objects.get_or_create( + collection=collection, + name=provider_data["name"], + defaults={ + 'description': provider_data.get('description', None), + 'roles': provider_data.get('roles', None), + 'url': provider_data.get('url', None) + } + ) + logger.debug( + '%s provider %s', + 'created' if created else 'updated', + provider.name, + extra={"provider": provider_data} + ) + provider_ids.append(provider.id) + # the duplicate here is necessary to update the values in + # case the object already exists + provider.description = provider_data.get('description', provider.description) + provider.roles = provider_data.get('roles', provider.roles) + provider.url = provider_data.get('url', provider.url) + provider.full_clean() + provider.save() + + # Delete providers that were not mentioned in the payload anymore + deleted = Provider.objects.filter(collection=collection).exclude(id__in=provider_ids + ).delete() + logger.info( + "deleted %d stale providers for collection %s", + deleted[0], + 
collection.name, + extra={"collection": collection.name} + ) + + def create(self, validated_data): + """ + Create and return a new `Collection` instance, given the validated data. + """ + providers_data = validated_data.pop('providers', []) + links_data = validated_data.pop('links', []) + collection = validate_uniqueness_and_create(Collection, validated_data) + self._update_or_create_providers(collection=collection, providers_data=providers_data) + update_or_create_links( + instance_type="collection", + model=CollectionLink, + instance=collection, + links_data=links_data + ) + return collection + + def update(self, instance, validated_data): + """ + Update and return an existing `Collection` instance, given the validated data. + """ + providers_data = validated_data.pop('providers', []) + links_data = validated_data.pop('links', []) + self._update_or_create_providers(collection=instance, providers_data=providers_data) + update_or_create_links( + instance_type="collection", + model=CollectionLink, + instance=instance, + links_data=links_data + ) + return super().update(instance, validated_data) + + def update_or_create(self, look_up, validated_data): + """ + Update or create the collection object selected by kwargs and return the instance. + + When no collection object matching the kwargs selection, a new object is created. + + Args: + validated_data: dict + Copy of the validated_data to use for update + kwargs: dict + Object selection arguments (NOTE: the selection arguments must match a unique + object in DB otherwise an IntegrityError will be raised) + + Returns: tuple + Collection instance and True if created otherwise false + """ + providers_data = validated_data.pop('providers', []) + links_data = validated_data.pop('links', []) + collection, created = Collection.objects.update_or_create( + **look_up, defaults=validated_data + ) + self._update_or_create_providers(collection=collection, providers_data=providers_data) + update_or_create_links( + instance_type="collection", + model=CollectionLink, + instance=collection, + links_data=links_data + ) + return collection, created + + def to_representation(self, instance): + name = instance.name + request = self.context.get("request") + representation = super().to_representation(instance) + + # Add hardcoded value Collection to response to conform to stac spec v1. 
+ representation['type'] = "Collection" + + # Remove property on older versions + if not is_api_version_1(request): + del representation['type'] + + # Add auto links + # We use OrderedDict, although it is not necessary, because the default serializer/model for + # links already uses OrderedDict, this way we keep consistency between auto link and user + # link + representation['links'][:0] = get_relation_links(request, 'collection-detail', [name]) + return representation + + def validate(self, attrs): + validate_json_payload(self) + return attrs diff --git a/app/stac_api/serializers/serializers.py b/app/stac_api/serializers/serializers.py index 30f0611b..3779dacc 100644 --- a/app/stac_api/serializers/serializers.py +++ b/app/stac_api/serializers/serializers.py @@ -3,7 +3,6 @@ from urllib.parse import urlparse from django.conf import settings -from django.contrib.gis.geos import GEOSGeometry from django.core.exceptions import ValidationError as CoreValidationError from django.utils.translation import gettext_lazy as _ @@ -14,15 +13,11 @@ from stac_api.models import Asset from stac_api.models import AssetUpload -from stac_api.models import Collection -from stac_api.models import CollectionAsset from stac_api.models import CollectionAssetUpload -from stac_api.models import CollectionLink from stac_api.models import Item from stac_api.models import ItemLink from stac_api.models import LandingPage from stac_api.models import LandingPageLink -from stac_api.models import Provider from stac_api.serializers.serializers_utils import DictSerializer from stac_api.serializers.serializers_utils import NonNullModelSerializer from stac_api.serializers.serializers_utils import UpsertModelSerializerMixin @@ -161,26 +156,6 @@ def to_representation(self, instance): return representation -class ProviderSerializer(NonNullModelSerializer): - - class Meta: - model = Provider - fields = ['name', 'roles', 'url', 'description'] - - -class CollectionLinkSerializer(NonNullModelSerializer): - - class Meta: - model = CollectionLink - fields = ['href', 'rel', 'title', 'type'] - - # NOTE: when explicitely declaring fields, we need to add the validation as for the field - # in model ! - type = serializers.CharField( - required=False, allow_blank=True, max_length=150, source="link_type" - ) - - class ItemLinkSerializer(NonNullModelSerializer): class Meta: @@ -382,107 +357,6 @@ def get_fields(self): return fields -class CollectionAssetBaseSerializer(NonNullModelSerializer, UpsertModelSerializerMixin): - '''Collection asset serializer base class - ''' - - class Meta: - model = CollectionAsset - fields = [ - 'id', - 'title', - 'type', - 'href', - 'description', - 'roles', - 'proj_epsg', - 'checksum_multihash', - 'created', - 'updated', - ] - validators = [] # Remove a default "unique together" constraint. - # (see: - # https://www.django-rest-framework.org/api-guide/validators/#limitations-of-validators) - - # NOTE: when explicitely declaring fields, we need to add the validation as for the field - # in model ! - id = serializers.CharField(source='name', max_length=255, validators=[validate_asset_name]) - title = serializers.CharField( - required=False, max_length=255, allow_null=True, allow_blank=False - ) - description = serializers.CharField(required=False, allow_blank=False, allow_null=True) - # Can't be a ChoiceField, as the validate method normalizes the MIME string only after it - # is read. Consistency is nevertheless guaranteed by the validate() and validate_type() methods. 
- type = serializers.CharField( - source='media_type', required=True, allow_null=False, allow_blank=False - ) - # Here we need to explicitely define these fields with the source, because they are renamed - # in the get_fields() method - proj_epsg = serializers.IntegerField(source='proj_epsg', allow_null=True, required=False) - # read only fields - checksum_multihash = serializers.CharField(source='checksum_multihash', read_only=True) - href = HrefField(source='file', read_only=True) - created = serializers.DateTimeField(read_only=True) - updated = serializers.DateTimeField(read_only=True) - - # helper variable to provide the collection for upsert validation - # see views.AssetDetail.perform_upsert - collection = None - - def create(self, validated_data): - asset = validate_uniqueness_and_create(CollectionAsset, validated_data) - return asset - - def update_or_create(self, look_up, validated_data): - """ - Update or create the asset object selected by kwargs and return the instance. - When no asset object matching the kwargs selection, a new asset is created. - Args: - validated_data: dict - Copy of the validated_data to use for update - kwargs: dict - Object selection arguments (NOTE: the selection arguments must match a unique - object in DB otherwise an IntegrityError will be raised) - Returns: tuple - Asset instance and True if created otherwise false - """ - asset, created = CollectionAsset.objects.update_or_create( - **look_up, - defaults=validated_data, - ) - return asset, created - - def validate_type(self, value): - ''' Validates the the field "type" - ''' - return normalize_and_validate_media_type(value) - - def validate(self, attrs): - name = attrs['name'] if not self.partial else attrs.get('name', self.instance.name) - media_type = attrs['media_type'] if not self.partial else attrs.get( - 'media_type', self.instance.media_type - ) - validate_asset_name_with_media_type(name, media_type) - - validate_json_payload(self) - - return attrs - - def get_fields(self): - fields = super().get_fields() - # This is a hack to allow fields with special characters - fields['proj:epsg'] = fields.pop('proj_epsg') - fields['file:checksum'] = fields.pop('checksum_multihash') - - # Older versions of the api still use different name - request = self.context.get('request') - if not is_api_version_1(request): - fields['checksum:multihash'] = fields.pop('file:checksum') - fields.pop('roles', None) - - return fields - - class AssetSerializer(AssetBaseSerializer): '''Asset serializer for the asset views @@ -538,27 +412,6 @@ def validate(self, attrs): return super().validate(attrs) -class CollectionAssetSerializer(CollectionAssetBaseSerializer): - '''Collection Asset serializer for the collection asset views - - This serializer adds the links list attribute. 
- ''' - - def to_representation(self, instance): - collection = instance.collection.name - name = instance.name - request = self.context.get("request") - representation = super().to_representation(instance) - # Add auto links - # We use OrderedDict, although it is not necessary, because the default serializer/model for - # links already uses OrderedDict, this way we keep consistency between auto link and user - # link - representation['links'] = get_relation_links( - request, 'collection-asset-detail', [collection, name] - ) - return representation - - class AssetsForItemSerializer(AssetBaseSerializer): '''Assets serializer for nesting them inside the item @@ -586,250 +439,6 @@ class Meta: ] -class CollectionAssetsForCollectionSerializer(CollectionAssetBaseSerializer): - '''Collection assets serializer for nesting them inside the collection - - Assets should be nested inside their collection but using a dictionary instead of a list and - without links. - ''' - - class Meta: - model = CollectionAsset - list_serializer_class = AssetsDictSerializer - fields = [ - 'id', - 'title', - 'type', - 'href', - 'description', - 'roles', - 'proj_epsg', - 'checksum_multihash', - 'created', - 'updated' - ] - - -class CollectionSerializer(NonNullModelSerializer, UpsertModelSerializerMixin): - - class Meta: - model = Collection - fields = [ - 'published', - 'stac_version', - 'stac_extensions', - 'id', - 'title', - 'description', - 'summaries', - 'extent', - 'providers', - 'license', - 'created', - 'updated', - 'links', - 'crs', - 'itemType', - 'assets' - ] - # crs not in sample data, but in specs.. - validators = [] # Remove a default "unique together" constraint. - # (see: - # https://www.django-rest-framework.org/api-guide/validators/#limitations-of-validators) - - published = serializers.BooleanField(write_only=True, default=True) - # NOTE: when explicitely declaring fields, we need to add the validation as for the field - # in model ! 
- id = serializers.CharField( - required=True, max_length=255, source="name", validators=[validate_name] - ) - title = serializers.CharField(required=False, allow_blank=False, default=None, max_length=255) - # Also links are required in the spec, the main links (self, root, items) are automatically - # generated hence here it is set to required=False which allows to add optional links that - # are not generated - links = CollectionLinkSerializer(required=False, many=True) - providers = ProviderSerializer(required=False, many=True) - - # read only fields - crs = serializers.SerializerMethodField(read_only=True) - created = serializers.DateTimeField(read_only=True) - updated = serializers.DateTimeField(read_only=True) - extent = serializers.SerializerMethodField(read_only=True) - summaries = serializers.SerializerMethodField(read_only=True) - stac_extensions = serializers.SerializerMethodField(read_only=True) - stac_version = serializers.SerializerMethodField(read_only=True) - itemType = serializers.ReadOnlyField(default="Feature") # pylint: disable=invalid-name - assets = CollectionAssetsForCollectionSerializer(many=True, read_only=True) - - def get_crs(self, obj): - return ["http://www.opengis.net/def/crs/OGC/1.3/CRS84"] - - def get_stac_extensions(self, obj): - return [] - - def get_stac_version(self, obj): - return get_stac_version(self.context.get('request')) - - def get_summaries(self, obj): - # Older versions of the api still use different name - request = self.context.get('request') - if not is_api_version_1(request): - return { - 'proj:epsg': obj.summaries_proj_epsg or [], - 'eo:gsd': obj.summaries_eo_gsd or [], - 'geoadmin:variant': obj.summaries_geoadmin_variant or [], - 'geoadmin:lang': obj.summaries_geoadmin_lang or [] - } - return { - 'proj:epsg': obj.summaries_proj_epsg or [], - 'gsd': obj.summaries_eo_gsd or [], - 'geoadmin:variant': obj.summaries_geoadmin_variant or [], - 'geoadmin:lang': obj.summaries_geoadmin_lang or [] - } - - def get_extent(self, obj): - start = obj.extent_start_datetime - end = obj.extent_end_datetime - if start is not None: - start = isoformat(start) - if end is not None: - end = isoformat(end) - - bbox = [0, 0, 0, 0] - if obj.extent_geometry is not None: - bbox = list(GEOSGeometry(obj.extent_geometry).extent) - - return { - "spatial": { - "bbox": [bbox] - }, - "temporal": { - "interval": [[start, end]] - }, - } - - def _update_or_create_providers(self, collection, providers_data): - provider_ids = [] - for provider_data in providers_data: - provider, created = Provider.objects.get_or_create( - collection=collection, - name=provider_data["name"], - defaults={ - 'description': provider_data.get('description', None), - 'roles': provider_data.get('roles', None), - 'url': provider_data.get('url', None) - } - ) - logger.debug( - '%s provider %s', - 'created' if created else 'updated', - provider.name, - extra={"provider": provider_data} - ) - provider_ids.append(provider.id) - # the duplicate here is necessary to update the values in - # case the object already exists - provider.description = provider_data.get('description', provider.description) - provider.roles = provider_data.get('roles', provider.roles) - provider.url = provider_data.get('url', provider.url) - provider.full_clean() - provider.save() - - # Delete providers that were not mentioned in the payload anymore - deleted = Provider.objects.filter(collection=collection).exclude(id__in=provider_ids - ).delete() - logger.info( - "deleted %d stale providers for collection %s", - deleted[0], - 
collection.name, - extra={"collection": collection.name} - ) - - def create(self, validated_data): - """ - Create and return a new `Collection` instance, given the validated data. - """ - providers_data = validated_data.pop('providers', []) - links_data = validated_data.pop('links', []) - collection = validate_uniqueness_and_create(Collection, validated_data) - self._update_or_create_providers(collection=collection, providers_data=providers_data) - update_or_create_links( - instance_type="collection", - model=CollectionLink, - instance=collection, - links_data=links_data - ) - return collection - - def update(self, instance, validated_data): - """ - Update and return an existing `Collection` instance, given the validated data. - """ - providers_data = validated_data.pop('providers', []) - links_data = validated_data.pop('links', []) - self._update_or_create_providers(collection=instance, providers_data=providers_data) - update_or_create_links( - instance_type="collection", - model=CollectionLink, - instance=instance, - links_data=links_data - ) - return super().update(instance, validated_data) - - def update_or_create(self, look_up, validated_data): - """ - Update or create the collection object selected by kwargs and return the instance. - - When no collection object matching the kwargs selection, a new object is created. - - Args: - validated_data: dict - Copy of the validated_data to use for update - kwargs: dict - Object selection arguments (NOTE: the selection arguments must match a unique - object in DB otherwise an IntegrityError will be raised) - - Returns: tuple - Collection instance and True if created otherwise false - """ - providers_data = validated_data.pop('providers', []) - links_data = validated_data.pop('links', []) - collection, created = Collection.objects.update_or_create( - **look_up, defaults=validated_data - ) - self._update_or_create_providers(collection=collection, providers_data=providers_data) - update_or_create_links( - instance_type="collection", - model=CollectionLink, - instance=collection, - links_data=links_data - ) - return collection, created - - def to_representation(self, instance): - name = instance.name - request = self.context.get("request") - representation = super().to_representation(instance) - - # Add hardcoded value Collection to response to conform to stac spec v1. 
- representation['type'] = "Collection" - - # Remove property on older versions - if not is_api_version_1(request): - del representation['type'] - - # Add auto links - # We use OrderedDict, although it is not necessary, because the default serializer/model for - # links already uses OrderedDict, this way we keep consistency between auto link and user - # link - representation['links'][:0] = get_relation_links(request, 'collection-detail', [name]) - return representation - - def validate(self, attrs): - validate_json_payload(self) - return attrs - - class ItemSerializer(NonNullModelSerializer, UpsertModelSerializerMixin): class Meta: diff --git a/app/stac_api/views/collection.py b/app/stac_api/views/collection.py index 6cbb9b56..2986300e 100644 --- a/app/stac_api/views/collection.py +++ b/app/stac_api/views/collection.py @@ -10,8 +10,8 @@ from stac_api.models import Collection from stac_api.models import CollectionAsset -from stac_api.serializers.serializers import CollectionAssetSerializer -from stac_api.serializers.serializers import CollectionSerializer +from stac_api.serializers.collection import CollectionAssetSerializer +from stac_api.serializers.collection import CollectionSerializer from stac_api.serializers.serializers_utils import get_relation_links from stac_api.utils import get_collection_asset_path from stac_api.validators_view import validate_collection diff --git a/app/tests/tests_09/test_serializer.py b/app/tests/tests_09/test_serializer.py index 0ec7f4a0..9953e2e9 100644 --- a/app/tests/tests_09/test_serializer.py +++ b/app/tests/tests_09/test_serializer.py @@ -13,8 +13,8 @@ from rest_framework.test import APIRequestFactory from stac_api.models import get_asset_path +from stac_api.serializers.collection import CollectionSerializer from stac_api.serializers.serializers import AssetSerializer -from stac_api.serializers.serializers import CollectionSerializer from stac_api.serializers.serializers import ItemSerializer from stac_api.utils import get_link from stac_api.utils import isoformat diff --git a/app/tests/tests_10/test_serializer.py b/app/tests/tests_10/test_serializer.py index aaa6aa30..30c8a889 100644 --- a/app/tests/tests_10/test_serializer.py +++ b/app/tests/tests_10/test_serializer.py @@ -13,8 +13,8 @@ from rest_framework.test import APIRequestFactory from stac_api.models import get_asset_path +from stac_api.serializers.collection import CollectionSerializer from stac_api.serializers.serializers import AssetSerializer -from stac_api.serializers.serializers import CollectionSerializer from stac_api.serializers.serializers import ItemSerializer from stac_api.utils import get_link from stac_api.utils import isoformat From 0bac3fa57308c1e230e356b55f5ec6ad80443a36 Mon Sep 17 00:00:00 2001 From: Benjamin Sugden Date: Wed, 11 Sep 2024 08:06:26 +0200 Subject: [PATCH 11/16] Move item serializers to own file --- .../commands/profile_item_serializer.py | 2 +- .../commands/profile_serializer_vs_no_drf.py | 2 +- app/stac_api/serializers/item.py | 406 ++++++++++++++++++ app/stac_api/serializers/serializers.py | 392 ----------------- app/stac_api/views/item.py | 4 +- app/stac_api/views/views.py | 2 +- app/tests/tests_09/test_serializer.py | 4 +- app/tests/tests_10/test_serializer.py | 4 +- 8 files changed, 415 insertions(+), 401 deletions(-) create mode 100644 app/stac_api/serializers/item.py diff --git a/app/stac_api/management/commands/profile_item_serializer.py b/app/stac_api/management/commands/profile_item_serializer.py index b348b175..47f2ba79 100644 --- 
a/app/stac_api/management/commands/profile_item_serializer.py +++ b/app/stac_api/management/commands/profile_item_serializer.py @@ -20,7 +20,7 @@ class Handler(CommandHandler): def profiling(self): # pylint: disable=import-outside-toplevel,possibly-unused-variable - from stac_api.serializers.serializers import ItemSerializer + from stac_api.serializers.item import ItemSerializer collection_id = self.options["collection"] qs = Item.objects.filter(collection__name=collection_id ).prefetch_related('assets', 'links')[:self.options['limit']] diff --git a/app/stac_api/management/commands/profile_serializer_vs_no_drf.py b/app/stac_api/management/commands/profile_serializer_vs_no_drf.py index 6d580a6c..d4b1ddca 100644 --- a/app/stac_api/management/commands/profile_serializer_vs_no_drf.py +++ b/app/stac_api/management/commands/profile_serializer_vs_no_drf.py @@ -21,7 +21,7 @@ def profiling(self): # pylint: disable=import-outside-toplevel,possibly-unused-variable self.print('Starting profiling') - from stac_api.serializers.serializers import ItemSerializer + from stac_api.serializers.item import ItemSerializer def serialize(qs): return { diff --git a/app/stac_api/serializers/item.py b/app/stac_api/serializers/item.py new file mode 100644 index 00000000..7c545855 --- /dev/null +++ b/app/stac_api/serializers/item.py @@ -0,0 +1,406 @@ +import logging + +from django.core.exceptions import ValidationError as CoreValidationError +from django.utils.translation import gettext_lazy as _ + +from rest_framework import serializers +from rest_framework_gis import serializers as gis_serializers + +from stac_api.models import Asset +from stac_api.models import Item +from stac_api.models import ItemLink +from stac_api.serializers.serializers import AssetsDictSerializer +from stac_api.serializers.serializers import HrefField +from stac_api.serializers.serializers_utils import NonNullModelSerializer +from stac_api.serializers.serializers_utils import UpsertModelSerializerMixin +from stac_api.serializers.serializers_utils import get_relation_links +from stac_api.serializers.serializers_utils import update_or_create_links +from stac_api.utils import get_stac_version +from stac_api.utils import is_api_version_1 +from stac_api.validators import normalize_and_validate_media_type +from stac_api.validators import validate_asset_name +from stac_api.validators import validate_asset_name_with_media_type +from stac_api.validators import validate_geoadmin_variant +from stac_api.validators import validate_href_url +from stac_api.validators import validate_item_properties_datetimes +from stac_api.validators import validate_name +from stac_api.validators_serializer import validate_json_payload +from stac_api.validators_serializer import validate_uniqueness_and_create + +logger = logging.getLogger(__name__) + + +class BboxSerializer(gis_serializers.GeoFeatureModelSerializer): + + class Meta: + model = Item + geo_field = "geometry" + auto_bbox = True + fields = ['geometry'] + + def to_representation(self, instance): + python_native = super().to_representation(instance) + return python_native['bbox'] + + +class ItemLinkSerializer(NonNullModelSerializer): + + class Meta: + model = ItemLink + fields = ['href', 'rel', 'title', 'type'] + + # NOTE: when explicitely declaring fields, we need to add the validation as for the field + # in model ! 
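+    # The payload field is called 'type' but is stored as 'link_type' on the model
+    # (presumably to avoid clashing with the Python built-in), hence the explicit
+    # source mapping below.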
+ type = serializers.CharField( + required=False, allow_blank=True, max_length=255, source="link_type" + ) + + +class ItemsPropertiesSerializer(serializers.Serializer): + # pylint: disable=abstract-method + # ItemsPropertiesSerializer is a nested serializer and don't directly create/write instances + # therefore we don't need to implement the super method create() and update() + + # NOTE: when explicitely declaring fields, we need to add the validation as for the field + # in model ! + datetime = serializers.DateTimeField(source='properties_datetime', required=False, default=None) + start_datetime = serializers.DateTimeField( + source='properties_start_datetime', required=False, default=None + ) + end_datetime = serializers.DateTimeField( + source='properties_end_datetime', required=False, default=None + ) + title = serializers.CharField( + source='properties_title', + required=False, + allow_blank=False, + allow_null=True, + max_length=255, + default=None + ) + created = serializers.DateTimeField(read_only=True) + updated = serializers.DateTimeField(read_only=True) + expires = serializers.DateTimeField(source='properties_expires', required=False, default=None) + + +class AssetBaseSerializer(NonNullModelSerializer, UpsertModelSerializerMixin): + '''Asset serializer base class + ''' + + class Meta: + model = Asset + fields = [ + 'id', + 'title', + 'type', + 'href', + 'description', + 'eo_gsd', + 'roles', + 'geoadmin_lang', + 'geoadmin_variant', + 'proj_epsg', + 'checksum_multihash', + 'created', + 'updated', + ] + validators = [] # Remove a default "unique together" constraint. + # (see: + # https://www.django-rest-framework.org/api-guide/validators/#limitations-of-validators) + + # NOTE: when explicitely declaring fields, we need to add the validation as for the field + # in model ! + id = serializers.CharField(source='name', max_length=255, validators=[validate_asset_name]) + title = serializers.CharField( + required=False, max_length=255, allow_null=True, allow_blank=False + ) + description = serializers.CharField(required=False, allow_blank=False, allow_null=True) + # Can't be a ChoiceField, as the validate method normalizes the MIME string only after it + # is read. Consistency is nevertheless guaranteed by the validate() and validate_type() methods. 
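+    # Hypothetical flow for a payload like {'type': 'image/TIFF'} (example value,
+    # not from the spec): validate_type() first canonicalizes the string via
+    # normalize_and_validate_media_type(), and validate() then checks the
+    # normalized value for consistency with the asset name
+    # (validate_asset_name_with_media_type).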
+ type = serializers.CharField( + source='media_type', required=True, allow_null=False, allow_blank=False + ) + # Here we need to explicitely define these fields with the source, because they are renamed + # in the get_fields() method + eo_gsd = serializers.FloatField(source='eo_gsd', required=False, allow_null=True) + geoadmin_lang = serializers.ChoiceField( + source='geoadmin_lang', + choices=Asset.Language.values, + required=False, + allow_null=True, + allow_blank=False + ) + geoadmin_variant = serializers.CharField( + source='geoadmin_variant', + max_length=25, + allow_blank=False, + allow_null=True, + required=False, + validators=[validate_geoadmin_variant] + ) + proj_epsg = serializers.IntegerField(source='proj_epsg', allow_null=True, required=False) + # read only fields + checksum_multihash = serializers.CharField(source='checksum_multihash', read_only=True) + href = HrefField(source='file', required=False) + created = serializers.DateTimeField(read_only=True) + updated = serializers.DateTimeField(read_only=True) + + # helper variable to provide the collection for upsert validation + # see views.AssetDetail.perform_upsert + collection = None + + def create(self, validated_data): + asset = validate_uniqueness_and_create(Asset, validated_data) + return asset + + def update_or_create(self, look_up, validated_data): + """ + Update or create the asset object selected by kwargs and return the instance. + When no asset object matching the kwargs selection, a new asset is created. + Args: + validated_data: dict + Copy of the validated_data to use for update + kwargs: dict + Object selection arguments (NOTE: the selection arguments must match a unique + object in DB otherwise an IntegrityError will be raised) + Returns: tuple + Asset instance and True if created otherwise false + """ + asset, created = Asset.objects.update_or_create(**look_up, defaults=validated_data) + return asset, created + + def validate_type(self, value): + ''' Validates the the field "type" + ''' + return normalize_and_validate_media_type(value) + + def validate(self, attrs): + name = attrs['name'] if not self.partial else attrs.get('name', self.instance.name) + media_type = attrs['media_type'] if not self.partial else attrs.get( + 'media_type', self.instance.media_type + ) + validate_asset_name_with_media_type(name, media_type) + + validate_json_payload(self) + + return attrs + + def get_fields(self): + fields = super().get_fields() + # This is a hack to allow fields with special characters + fields['gsd'] = fields.pop('eo_gsd') + fields['proj:epsg'] = fields.pop('proj_epsg') + fields['geoadmin:variant'] = fields.pop('geoadmin_variant') + fields['geoadmin:lang'] = fields.pop('geoadmin_lang') + fields['file:checksum'] = fields.pop('checksum_multihash') + + # Older versions of the api still use different name + request = self.context.get('request') + if not is_api_version_1(request): + fields['checksum:multihash'] = fields.pop('file:checksum') + fields['eo:gsd'] = fields.pop('gsd') + fields.pop('roles', None) + + return fields + + +class AssetSerializer(AssetBaseSerializer): + '''Asset serializer for the asset views + + This serializer adds the links list attribute. 
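+
+    The links are not a declared serializer field: they are computed from the
+    request context in to_representation(), so they are read-only and never part
+    of the validated payload.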
+ ''' + + def to_representation(self, instance): + collection = instance.item.collection.name + item = instance.item.name + name = instance.name + request = self.context.get("request") + representation = super().to_representation(instance) + # Add auto links + # We use OrderedDict, although it is not necessary, because the default serializer/model for + # links already uses OrderedDict, this way we keep consistency between auto link and user + # link + representation['links'] = get_relation_links( + request, 'asset-detail', [collection, item, name] + ) + return representation + + def _validate_href_field(self, attrs): + """Only allow the href field if the collection allows for external assets + + Raise an exception, this replicates the previous behaviour when href + was always read_only + """ + # the href field is translated to the file field here + if 'file' in attrs: + if self.collection: + collection = self.collection + else: + raise LookupError("No collection defined.") + + if not collection.allow_external_assets: + logger.info( + 'Attempted external asset upload with no permission', + extra={ + 'collection': self.collection, 'attrs': attrs + } + ) + errors = {'href': _("Found read-only property in payload")} + raise serializers.ValidationError(code="payload", detail=errors) + + try: + validate_href_url(attrs['file'], collection) + except CoreValidationError as e: + errors = {'href': e.message} + raise serializers.ValidationError(code='payload', detail=errors) + + def validate(self, attrs): + self._validate_href_field(attrs) + return super().validate(attrs) + + +class AssetsForItemSerializer(AssetBaseSerializer): + '''Assets serializer for nesting them inside the item + + Assets should be nested inside their item but using a dictionary instead of a list and without + links. + ''' + + class Meta: + model = Asset + list_serializer_class = AssetsDictSerializer + fields = [ + 'id', + 'title', + 'type', + 'href', + 'description', + 'roles', + 'eo_gsd', + 'geoadmin_lang', + 'geoadmin_variant', + 'proj_epsg', + 'checksum_multihash', + 'created', + 'updated' + ] + + +class ItemSerializer(NonNullModelSerializer, UpsertModelSerializerMixin): + + class Meta: + model = Item + fields = [ + 'id', + 'collection', + 'type', + 'stac_version', + 'geometry', + 'bbox', + 'properties', + 'stac_extensions', + 'links', + 'assets' + ] + validators = [] # Remove a default "unique together" constraint. + # (see: + # https://www.django-rest-framework.org/api-guide/validators/#limitations-of-validators) + + # NOTE: when explicitely declaring fields, we need to add the validation as for the field + # in model ! 
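+    # 'id' maps to the model's 'name' column; validate_name mirrors the model-level
+    # validator so that invalid names are rejected during payload validation rather
+    # than only at model full_clean().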
+ id = serializers.CharField( + source='name', required=True, max_length=255, validators=[validate_name] + ) + properties = ItemsPropertiesSerializer(source='*', required=True) + geometry = gis_serializers.GeometryField(required=True) + links = ItemLinkSerializer(required=False, many=True) + # read only fields + type = serializers.SerializerMethodField() + collection = serializers.SlugRelatedField(slug_field='name', read_only=True) + bbox = BboxSerializer(source='*', read_only=True) + assets = AssetsForItemSerializer(many=True, read_only=True) + stac_extensions = serializers.SerializerMethodField() + stac_version = serializers.SerializerMethodField() + + def get_type(self, obj): + return 'Feature' + + def get_stac_extensions(self, obj): + return [] + + def get_stac_version(self, obj): + return get_stac_version(self.context.get('request')) + + def to_representation(self, instance): + collection = instance.collection.name + name = instance.name + request = self.context.get("request") + representation = super().to_representation(instance) + # Add auto links + # We use OrderedDict, although it is not necessary, because the default serializer/model for + # links already uses OrderedDict, this way we keep consistency between auto link and user + # link + representation['links'][:0] = get_relation_links(request, 'item-detail', [collection, name]) + representation['stac_extensions'] = [ + # Extension provides schema for the 'expires' timestamp + "https://stac-extensions.github.io/timestamps/v1.1.0/schema.json" + ] + return representation + + def create(self, validated_data): + links_data = validated_data.pop('links', []) + item = validate_uniqueness_and_create(Item, validated_data) + update_or_create_links( + instance_type="item", model=ItemLink, instance=item, links_data=links_data + ) + return item + + def update(self, instance, validated_data): + links_data = validated_data.pop('links', []) + update_or_create_links( + instance_type="item", model=ItemLink, instance=instance, links_data=links_data + ) + return super().update(instance, validated_data) + + def update_or_create(self, look_up, validated_data): + """ + Update or create the item object selected by kwargs and return the instance. + When no item object matching the kwargs selection, a new item is created. 
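+        Note that links are upserted separately via update_or_create_links() once
+        the item itself has been created or updated.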
+ Args: + validated_data: dict + Copy of the validated_data to use for update + kwargs: dict + Object selection arguments (NOTE: the selection arguments must match a unique + object in DB otherwise an IntegrityError will be raised) + Returns: tuple + Item instance and True if created otherwise false + """ + links_data = validated_data.pop('links', []) + item, created = Item.objects.update_or_create(**look_up, defaults=validated_data) + update_or_create_links( + instance_type="item", model=ItemLink, instance=item, links_data=links_data + ) + return item, created + + def validate(self, attrs): + if ( + not self.partial or \ + 'properties_datetime' in attrs or \ + 'properties_start_datetime' in attrs or \ + 'properties_end_datetime' in attrs or \ + 'properties_expires' in attrs + ): + validate_item_properties_datetimes( + attrs.get('properties_datetime', None), + attrs.get('properties_start_datetime', None), + attrs.get('properties_end_datetime', None), + attrs.get('properties_expires', None) + ) + else: + logger.info( + 'Skip validation of item properties datetimes; partial update without datetimes' + ) + + validate_json_payload(self) + + return attrs diff --git a/app/stac_api/serializers/serializers.py b/app/stac_api/serializers/serializers.py index 3779dacc..ecb87a3e 100644 --- a/app/stac_api/serializers/serializers.py +++ b/app/stac_api/serializers/serializers.py @@ -3,44 +3,28 @@ from urllib.parse import urlparse from django.conf import settings -from django.core.exceptions import ValidationError as CoreValidationError from django.utils.translation import gettext_lazy as _ from rest_framework import serializers from rest_framework.utils.serializer_helpers import ReturnDict from rest_framework.validators import UniqueValidator -from rest_framework_gis import serializers as gis_serializers -from stac_api.models import Asset from stac_api.models import AssetUpload from stac_api.models import CollectionAssetUpload -from stac_api.models import Item -from stac_api.models import ItemLink from stac_api.models import LandingPage from stac_api.models import LandingPageLink from stac_api.serializers.serializers_utils import DictSerializer from stac_api.serializers.serializers_utils import NonNullModelSerializer -from stac_api.serializers.serializers_utils import UpsertModelSerializerMixin -from stac_api.serializers.serializers_utils import get_relation_links -from stac_api.serializers.serializers_utils import update_or_create_links from stac_api.utils import build_asset_href from stac_api.utils import get_browser_url from stac_api.utils import get_stac_version from stac_api.utils import get_url from stac_api.utils import is_api_version_1 from stac_api.utils import isoformat -from stac_api.validators import normalize_and_validate_media_type -from stac_api.validators import validate_asset_name -from stac_api.validators import validate_asset_name_with_media_type from stac_api.validators import validate_checksum_multihash_sha256 from stac_api.validators import validate_content_encoding -from stac_api.validators import validate_geoadmin_variant -from stac_api.validators import validate_href_url -from stac_api.validators import validate_item_properties_datetimes from stac_api.validators import validate_md5_parts from stac_api.validators import validate_name -from stac_api.validators_serializer import validate_json_payload -from stac_api.validators_serializer import validate_uniqueness_and_create logger = logging.getLogger(__name__) @@ -156,59 +140,6 @@ def to_representation(self, instance): return 
representation -class ItemLinkSerializer(NonNullModelSerializer): - - class Meta: - model = ItemLink - fields = ['href', 'rel', 'title', 'type'] - - # NOTE: when explicitely declaring fields, we need to add the validation as for the field - # in model ! - type = serializers.CharField( - required=False, allow_blank=True, max_length=255, source="link_type" - ) - - -class ItemsPropertiesSerializer(serializers.Serializer): - # pylint: disable=abstract-method - # ItemsPropertiesSerializer is a nested serializer and don't directly create/write instances - # therefore we don't need to implement the super method create() and update() - - # NOTE: when explicitely declaring fields, we need to add the validation as for the field - # in model ! - datetime = serializers.DateTimeField(source='properties_datetime', required=False, default=None) - start_datetime = serializers.DateTimeField( - source='properties_start_datetime', required=False, default=None - ) - end_datetime = serializers.DateTimeField( - source='properties_end_datetime', required=False, default=None - ) - title = serializers.CharField( - source='properties_title', - required=False, - allow_blank=False, - allow_null=True, - max_length=255, - default=None - ) - created = serializers.DateTimeField(read_only=True) - updated = serializers.DateTimeField(read_only=True) - expires = serializers.DateTimeField(source='properties_expires', required=False, default=None) - - -class BboxSerializer(gis_serializers.GeoFeatureModelSerializer): - - class Meta: - model = Item - geo_field = "geometry" - auto_bbox = True - fields = ['geometry'] - - def to_representation(self, instance): - python_native = super().to_representation(instance) - return python_native['bbox'] - - class AssetsDictSerializer(DictSerializer): '''Assets serializer list to dictionary @@ -236,329 +167,6 @@ def to_internal_value(self, data): return data -class AssetBaseSerializer(NonNullModelSerializer, UpsertModelSerializerMixin): - '''Asset serializer base class - ''' - - class Meta: - model = Asset - fields = [ - 'id', - 'title', - 'type', - 'href', - 'description', - 'eo_gsd', - 'roles', - 'geoadmin_lang', - 'geoadmin_variant', - 'proj_epsg', - 'checksum_multihash', - 'created', - 'updated', - ] - validators = [] # Remove a default "unique together" constraint. - # (see: - # https://www.django-rest-framework.org/api-guide/validators/#limitations-of-validators) - - # NOTE: when explicitely declaring fields, we need to add the validation as for the field - # in model ! - id = serializers.CharField(source='name', max_length=255, validators=[validate_asset_name]) - title = serializers.CharField( - required=False, max_length=255, allow_null=True, allow_blank=False - ) - description = serializers.CharField(required=False, allow_blank=False, allow_null=True) - # Can't be a ChoiceField, as the validate method normalizes the MIME string only after it - # is read. Consistency is nevertheless guaranteed by the validate() and validate_type() methods. 
- type = serializers.CharField( - source='media_type', required=True, allow_null=False, allow_blank=False - ) - # Here we need to explicitely define these fields with the source, because they are renamed - # in the get_fields() method - eo_gsd = serializers.FloatField(source='eo_gsd', required=False, allow_null=True) - geoadmin_lang = serializers.ChoiceField( - source='geoadmin_lang', - choices=Asset.Language.values, - required=False, - allow_null=True, - allow_blank=False - ) - geoadmin_variant = serializers.CharField( - source='geoadmin_variant', - max_length=25, - allow_blank=False, - allow_null=True, - required=False, - validators=[validate_geoadmin_variant] - ) - proj_epsg = serializers.IntegerField(source='proj_epsg', allow_null=True, required=False) - # read only fields - checksum_multihash = serializers.CharField(source='checksum_multihash', read_only=True) - href = HrefField(source='file', required=False) - created = serializers.DateTimeField(read_only=True) - updated = serializers.DateTimeField(read_only=True) - - # helper variable to provide the collection for upsert validation - # see views.AssetDetail.perform_upsert - collection = None - - def create(self, validated_data): - asset = validate_uniqueness_and_create(Asset, validated_data) - return asset - - def update_or_create(self, look_up, validated_data): - """ - Update or create the asset object selected by kwargs and return the instance. - When no asset object matching the kwargs selection, a new asset is created. - Args: - validated_data: dict - Copy of the validated_data to use for update - kwargs: dict - Object selection arguments (NOTE: the selection arguments must match a unique - object in DB otherwise an IntegrityError will be raised) - Returns: tuple - Asset instance and True if created otherwise false - """ - asset, created = Asset.objects.update_or_create(**look_up, defaults=validated_data) - return asset, created - - def validate_type(self, value): - ''' Validates the the field "type" - ''' - return normalize_and_validate_media_type(value) - - def validate(self, attrs): - name = attrs['name'] if not self.partial else attrs.get('name', self.instance.name) - media_type = attrs['media_type'] if not self.partial else attrs.get( - 'media_type', self.instance.media_type - ) - validate_asset_name_with_media_type(name, media_type) - - validate_json_payload(self) - - return attrs - - def get_fields(self): - fields = super().get_fields() - # This is a hack to allow fields with special characters - fields['gsd'] = fields.pop('eo_gsd') - fields['proj:epsg'] = fields.pop('proj_epsg') - fields['geoadmin:variant'] = fields.pop('geoadmin_variant') - fields['geoadmin:lang'] = fields.pop('geoadmin_lang') - fields['file:checksum'] = fields.pop('checksum_multihash') - - # Older versions of the api still use different name - request = self.context.get('request') - if not is_api_version_1(request): - fields['checksum:multihash'] = fields.pop('file:checksum') - fields['eo:gsd'] = fields.pop('gsd') - fields.pop('roles', None) - - return fields - - -class AssetSerializer(AssetBaseSerializer): - '''Asset serializer for the asset views - - This serializer adds the links list attribute. 
- ''' - - def to_representation(self, instance): - collection = instance.item.collection.name - item = instance.item.name - name = instance.name - request = self.context.get("request") - representation = super().to_representation(instance) - # Add auto links - # We use OrderedDict, although it is not necessary, because the default serializer/model for - # links already uses OrderedDict, this way we keep consistency between auto link and user - # link - representation['links'] = get_relation_links( - request, 'asset-detail', [collection, item, name] - ) - return representation - - def _validate_href_field(self, attrs): - """Only allow the href field if the collection allows for external assets - - Raise an exception, this replicates the previous behaviour when href - was always read_only - """ - # the href field is translated to the file field here - if 'file' in attrs: - if self.collection: - collection = self.collection - else: - raise LookupError("No collection defined.") - - if not collection.allow_external_assets: - logger.info( - 'Attempted external asset upload with no permission', - extra={ - 'collection': self.collection, 'attrs': attrs - } - ) - errors = {'href': _("Found read-only property in payload")} - raise serializers.ValidationError(code="payload", detail=errors) - - try: - validate_href_url(attrs['file'], collection) - except CoreValidationError as e: - errors = {'href': e.message} - raise serializers.ValidationError(code='payload', detail=errors) - - def validate(self, attrs): - self._validate_href_field(attrs) - return super().validate(attrs) - - -class AssetsForItemSerializer(AssetBaseSerializer): - '''Assets serializer for nesting them inside the item - - Assets should be nested inside their item but using a dictionary instead of a list and without - links. - ''' - - class Meta: - model = Asset - list_serializer_class = AssetsDictSerializer - fields = [ - 'id', - 'title', - 'type', - 'href', - 'description', - 'roles', - 'eo_gsd', - 'geoadmin_lang', - 'geoadmin_variant', - 'proj_epsg', - 'checksum_multihash', - 'created', - 'updated' - ] - - -class ItemSerializer(NonNullModelSerializer, UpsertModelSerializerMixin): - - class Meta: - model = Item - fields = [ - 'id', - 'collection', - 'type', - 'stac_version', - 'geometry', - 'bbox', - 'properties', - 'stac_extensions', - 'links', - 'assets' - ] - validators = [] # Remove a default "unique together" constraint. - # (see: - # https://www.django-rest-framework.org/api-guide/validators/#limitations-of-validators) - - # NOTE: when explicitely declaring fields, we need to add the validation as for the field - # in model ! 
- id = serializers.CharField( - source='name', required=True, max_length=255, validators=[validate_name] - ) - properties = ItemsPropertiesSerializer(source='*', required=True) - geometry = gis_serializers.GeometryField(required=True) - links = ItemLinkSerializer(required=False, many=True) - # read only fields - type = serializers.SerializerMethodField() - collection = serializers.SlugRelatedField(slug_field='name', read_only=True) - bbox = BboxSerializer(source='*', read_only=True) - assets = AssetsForItemSerializer(many=True, read_only=True) - stac_extensions = serializers.SerializerMethodField() - stac_version = serializers.SerializerMethodField() - - def get_type(self, obj): - return 'Feature' - - def get_stac_extensions(self, obj): - return [] - - def get_stac_version(self, obj): - return get_stac_version(self.context.get('request')) - - def to_representation(self, instance): - collection = instance.collection.name - name = instance.name - request = self.context.get("request") - representation = super().to_representation(instance) - # Add auto links - # We use OrderedDict, although it is not necessary, because the default serializer/model for - # links already uses OrderedDict, this way we keep consistency between auto link and user - # link - representation['links'][:0] = get_relation_links(request, 'item-detail', [collection, name]) - representation['stac_extensions'] = [ - # Extension provides schema for the 'expires' timestamp - "https://stac-extensions.github.io/timestamps/v1.1.0/schema.json" - ] - return representation - - def create(self, validated_data): - links_data = validated_data.pop('links', []) - item = validate_uniqueness_and_create(Item, validated_data) - update_or_create_links( - instance_type="item", model=ItemLink, instance=item, links_data=links_data - ) - return item - - def update(self, instance, validated_data): - links_data = validated_data.pop('links', []) - update_or_create_links( - instance_type="item", model=ItemLink, instance=instance, links_data=links_data - ) - return super().update(instance, validated_data) - - def update_or_create(self, look_up, validated_data): - """ - Update or create the item object selected by kwargs and return the instance. - When no item object matching the kwargs selection, a new item is created. 
- Args: - validated_data: dict - Copy of the validated_data to use for update - kwargs: dict - Object selection arguments (NOTE: the selection arguments must match a unique - object in DB otherwise an IntegrityError will be raised) - Returns: tuple - Item instance and True if created otherwise false - """ - links_data = validated_data.pop('links', []) - item, created = Item.objects.update_or_create(**look_up, defaults=validated_data) - update_or_create_links( - instance_type="item", model=ItemLink, instance=item, links_data=links_data - ) - return item, created - - def validate(self, attrs): - if ( - not self.partial or \ - 'properties_datetime' in attrs or \ - 'properties_start_datetime' in attrs or \ - 'properties_end_datetime' in attrs or \ - 'properties_expires' in attrs - ): - validate_item_properties_datetimes( - attrs.get('properties_datetime', None), - attrs.get('properties_start_datetime', None), - attrs.get('properties_end_datetime', None), - attrs.get('properties_expires', None) - ) - else: - logger.info( - 'Skip validation of item properties datetimes; partial update without datetimes' - ) - - validate_json_payload(self) - - return attrs - - class AssetUploadListSerializer(serializers.ListSerializer): # pylint: disable=abstract-method diff --git a/app/stac_api/views/item.py b/app/stac_api/views/item.py index a760ec37..6d9469b6 100644 --- a/app/stac_api/views/item.py +++ b/app/stac_api/views/item.py @@ -14,8 +14,8 @@ from stac_api.models import Asset from stac_api.models import Collection from stac_api.models import Item -from stac_api.serializers.serializers import AssetSerializer -from stac_api.serializers.serializers import ItemSerializer +from stac_api.serializers.item import AssetSerializer +from stac_api.serializers.item import ItemSerializer from stac_api.serializers.serializers_utils import get_relation_links from stac_api.utils import get_asset_path from stac_api.utils import utc_aware diff --git a/app/stac_api/views/views.py b/app/stac_api/views/views.py index d74d4195..bda80893 100644 --- a/app/stac_api/views/views.py +++ b/app/stac_api/views/views.py @@ -33,11 +33,11 @@ from stac_api.pagination import ExtApiPagination from stac_api.pagination import GetPostCursorPagination from stac_api.s3_multipart_upload import MultipartUpload +from stac_api.serializers.item import ItemSerializer from stac_api.serializers.serializers import AssetUploadPartsSerializer from stac_api.serializers.serializers import AssetUploadSerializer from stac_api.serializers.serializers import CollectionAssetUploadSerializer from stac_api.serializers.serializers import ConformancePageSerializer -from stac_api.serializers.serializers import ItemSerializer from stac_api.serializers.serializers import LandingPageSerializer from stac_api.serializers.serializers_utils import get_relation_links from stac_api.utils import call_calculate_extent diff --git a/app/tests/tests_09/test_serializer.py b/app/tests/tests_09/test_serializer.py index 9953e2e9..66374304 100644 --- a/app/tests/tests_09/test_serializer.py +++ b/app/tests/tests_09/test_serializer.py @@ -14,8 +14,8 @@ from stac_api.models import get_asset_path from stac_api.serializers.collection import CollectionSerializer -from stac_api.serializers.serializers import AssetSerializer -from stac_api.serializers.serializers import ItemSerializer +from stac_api.serializers.item import AssetSerializer +from stac_api.serializers.item import ItemSerializer from stac_api.utils import get_link from stac_api.utils import isoformat from stac_api.utils import 
utc_aware diff --git a/app/tests/tests_10/test_serializer.py b/app/tests/tests_10/test_serializer.py index 30c8a889..79952a5c 100644 --- a/app/tests/tests_10/test_serializer.py +++ b/app/tests/tests_10/test_serializer.py @@ -14,8 +14,8 @@ from stac_api.models import get_asset_path from stac_api.serializers.collection import CollectionSerializer -from stac_api.serializers.serializers import AssetSerializer -from stac_api.serializers.serializers import ItemSerializer +from stac_api.serializers.item import AssetSerializer +from stac_api.serializers.item import ItemSerializer from stac_api.utils import get_link from stac_api.utils import isoformat from stac_api.utils import utc_aware From 57a0dcc0fff9f03befffab5b156541b2d0172e51 Mon Sep 17 00:00:00 2001 From: Benjamin Sugden Date: Wed, 11 Sep 2024 13:41:15 +0200 Subject: [PATCH 12/16] Move upload views to own file --- app/stac_api/urls.py | 20 +- app/stac_api/views/upload.py | 445 +++++++++++++++++++++++++++++++++++ app/stac_api/views/views.py | 433 +--------------------------------- 3 files changed, 456 insertions(+), 442 deletions(-) create mode 100644 app/stac_api/views/upload.py diff --git a/app/stac_api/urls.py b/app/stac_api/urls.py index 95b2ba46..d3c28032 100644 --- a/app/stac_api/urls.py +++ b/app/stac_api/urls.py @@ -12,16 +12,16 @@ from stac_api.views.item import AssetsList from stac_api.views.item import ItemDetail from stac_api.views.item import ItemsList -from stac_api.views.views import AssetUploadAbort -from stac_api.views.views import AssetUploadComplete -from stac_api.views.views import AssetUploadDetail -from stac_api.views.views import AssetUploadPartsList -from stac_api.views.views import AssetUploadsList -from stac_api.views.views import CollectionAssetUploadAbort -from stac_api.views.views import CollectionAssetUploadComplete -from stac_api.views.views import CollectionAssetUploadDetail -from stac_api.views.views import CollectionAssetUploadPartsList -from stac_api.views.views import CollectionAssetUploadsList +from stac_api.views.upload import AssetUploadAbort +from stac_api.views.upload import AssetUploadComplete +from stac_api.views.upload import AssetUploadDetail +from stac_api.views.upload import AssetUploadPartsList +from stac_api.views.upload import AssetUploadsList +from stac_api.views.upload import CollectionAssetUploadAbort +from stac_api.views.upload import CollectionAssetUploadComplete +from stac_api.views.upload import CollectionAssetUploadDetail +from stac_api.views.upload import CollectionAssetUploadPartsList +from stac_api.views.upload import CollectionAssetUploadsList from stac_api.views.views import ConformancePageDetail from stac_api.views.views import LandingPageDetail from stac_api.views.views import SearchList diff --git a/app/stac_api/views/upload.py b/app/stac_api/views/upload.py new file mode 100644 index 00000000..ec8c9920 --- /dev/null +++ b/app/stac_api/views/upload.py @@ -0,0 +1,445 @@ +import logging +from datetime import datetime +from operator import itemgetter + +from django.db import IntegrityError +from django.db import transaction +from django.utils.translation import gettext_lazy as _ + +from rest_framework import generics +from rest_framework import mixins +from rest_framework import serializers +from rest_framework.exceptions import APIException +from rest_framework.generics import get_object_or_404 +from rest_framework.response import Response +from rest_framework_condition import etag + +from stac_api.exceptions import UploadInProgressError +from stac_api.exceptions import 
UploadNotInProgressError +from stac_api.models import Asset +from stac_api.models import AssetUpload +from stac_api.models import BaseAssetUpload +from stac_api.models import CollectionAsset +from stac_api.models import CollectionAssetUpload +from stac_api.pagination import ExtApiPagination +from stac_api.s3_multipart_upload import MultipartUpload +from stac_api.serializers.serializers import AssetUploadPartsSerializer +from stac_api.serializers.serializers import AssetUploadSerializer +from stac_api.serializers.serializers import CollectionAssetUploadSerializer +from stac_api.utils import get_asset_path +from stac_api.utils import get_collection_asset_path +from stac_api.utils import select_s3_bucket +from stac_api.utils import utc_aware +from stac_api.validators_view import validate_asset +from stac_api.validators_view import validate_collection_asset +from stac_api.views import views_mixins +from stac_api.views.views import get_etag + +logger = logging.getLogger(__name__) + + +def get_asset_upload_etag(request, *args, **kwargs): + '''Get the ETag for an asset upload object + + The ETag is an UUID4 computed on each object changes + ''' + return get_etag( + AssetUpload.objects.filter( + asset__item__collection__name=kwargs['collection_name'], + asset__item__name=kwargs['item_name'], + asset__name=kwargs['asset_name'], + upload_id=kwargs['upload_id'] + ) + ) + + +def get_collection_asset_upload_etag(request, *args, **kwargs): + '''Get the ETag for a collection asset upload object + + The ETag is an UUID4 computed on each object changes + ''' + return get_etag( + CollectionAssetUpload.objects.filter( + asset__collection__name=kwargs['collection_name'], + asset__name=kwargs['asset_name'], + upload_id=kwargs['upload_id'] + ) + ) + + +class SharedAssetUploadBase(generics.GenericAPIView): + """SharedAssetUploadBase provides a base view for asset uploads and collection asset uploads. 
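+
+    Subclasses must provide get_queryset() and get_asset_or_404(). The shared
+    helpers dispatch on the asset type where the two flavours differ: a
+    CollectionAsset hangs directly off a collection (see get_path() and
+    log_extra()), while a regular Asset is addressed through its item.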
+ """ + lookup_url_kwarg = "upload_id" + lookup_field = "upload_id" + + def get_queryset(self): + raise NotImplementedError("get_queryset() not implemented") + + def get_in_progress_queryset(self): + return self.get_queryset().filter(status=BaseAssetUpload.Status.IN_PROGRESS) + + def get_asset_or_404(self): + raise NotImplementedError("get_asset_or_404() not implemented") + + def log_extra(self, asset): + if isinstance(asset, CollectionAsset): + return {'collection': asset.collection.name, 'asset': asset.name} + return { + 'collection': asset.item.collection.name, 'item': asset.item.name, 'asset': asset.name + } + + def get_path(self, asset): + if isinstance(asset, CollectionAsset): + return get_collection_asset_path(asset.collection, asset.name) + return get_asset_path(asset.item, asset.name) + + def _save_asset_upload(self, executor, serializer, key, asset, upload_id, urls): + try: + with transaction.atomic(): + serializer.save(asset=asset, upload_id=upload_id, urls=urls) + except IntegrityError as error: + logger.error( + 'Failed to create asset upload multipart: %s', error, extra=self.log_extra(asset) + ) + if bool(self.get_in_progress_queryset()): + raise UploadInProgressError( + data={"upload_id": self.get_in_progress_queryset()[0].upload_id} + ) from None + raise + + def create_multipart_upload(self, executor, serializer, validated_data, asset): + key = self.get_path(asset) + + upload_id = executor.create_multipart_upload( + key, + asset, + validated_data['checksum_multihash'], + validated_data['update_interval'], + validated_data['content_encoding'] + ) + urls = [] + sorted_md5_parts = sorted(validated_data['md5_parts'], key=itemgetter('part_number')) + + try: + for part in sorted_md5_parts: + urls.append( + executor.create_presigned_url( + key, asset, part['part_number'], upload_id, part['md5'] + ) + ) + + self._save_asset_upload(executor, serializer, key, asset, upload_id, urls) + except APIException as err: + executor.abort_multipart_upload(key, asset, upload_id) + raise + + def complete_multipart_upload(self, executor, validated_data, asset_upload, asset): + key = self.get_path(asset) + parts = validated_data.get('parts', None) + if parts is None: + raise serializers.ValidationError({ + 'parts': _("Missing required field") + }, code='missing') + if len(parts) > asset_upload.number_parts: + raise serializers.ValidationError({'parts': [_("Too many parts")]}, code='invalid') + if len(parts) < asset_upload.number_parts: + raise serializers.ValidationError({'parts': [_("Too few parts")]}, code='invalid') + if asset_upload.status != BaseAssetUpload.Status.IN_PROGRESS: + raise UploadNotInProgressError() + executor.complete_multipart_upload(key, asset, parts, asset_upload.upload_id) + asset_upload.update_asset_from_upload() + asset_upload.status = BaseAssetUpload.Status.COMPLETED + asset_upload.ended = utc_aware(datetime.utcnow()) + asset_upload.urls = [] + asset_upload.save() + + def abort_multipart_upload(self, executor, asset_upload, asset): + key = self.get_path(asset) + executor.abort_multipart_upload(key, asset, asset_upload.upload_id) + asset_upload.status = BaseAssetUpload.Status.ABORTED + asset_upload.ended = utc_aware(datetime.utcnow()) + asset_upload.urls = [] + asset_upload.save() + + def list_multipart_upload_parts(self, executor, asset_upload, asset, limit, offset): + key = self.get_path(asset) + return executor.list_upload_parts(key, asset, asset_upload.upload_id, limit, offset) + + +class AssetUploadBase(SharedAssetUploadBase): + """AssetUploadBase is the base for all 
asset (not collection asset) upload views. + """ + serializer_class = AssetUploadSerializer + + def get_queryset(self): + return AssetUpload.objects.filter( + asset__item__collection__name=self.kwargs['collection_name'], + asset__item__name=self.kwargs['item_name'], + asset__name=self.kwargs['asset_name'] + ).prefetch_related('asset') + + def get_asset_or_404(self): + return get_object_or_404( + Asset.objects.all(), + name=self.kwargs['asset_name'], + item__name=self.kwargs['item_name'], + item__collection__name=self.kwargs['collection_name'] + ) + + +class AssetUploadsList(AssetUploadBase, mixins.ListModelMixin, views_mixins.CreateModelMixin): + + class ExternalDisallowedException(Exception): + pass + + def post(self, request, *args, **kwargs): + try: + return self.create(request, *args, **kwargs) + except self.ExternalDisallowedException as ex: + data = { + "code": 400, + "description": "Not allowed to create multipart uploads on external assets" + } + return Response(status=400, exception=True, data=data) + + def get(self, request, *args, **kwargs): + validate_asset(self.kwargs) + return self.list(request, *args, **kwargs) + + def get_success_headers(self, data): + return {'Location': '/'.join([self.request.build_absolute_uri(), data['upload_id']])} + + def perform_create(self, serializer): + data = serializer.validated_data + asset = self.get_asset_or_404() + collection = asset.item.collection + + if asset.is_external: + raise self.ExternalDisallowedException() + + s3_bucket = select_s3_bucket(collection.name) + executor = MultipartUpload(s3_bucket) + + self.create_multipart_upload(executor, serializer, data, asset) + + def get_queryset(self): + queryset = super().get_queryset() + + status = self.request.query_params.get('status', None) + if status: + queryset = queryset.filter_by_status(status) + + return queryset + + +class AssetUploadDetail(AssetUploadBase, mixins.RetrieveModelMixin, views_mixins.DestroyModelMixin): + + @etag(get_asset_upload_etag) + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + +class AssetUploadComplete(AssetUploadBase, views_mixins.UpdateInsertModelMixin): + + def post(self, request, *args, **kwargs): + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + + def perform_update(self, serializer): + asset = serializer.instance.asset + + collection = asset.item.collection + + s3_bucket = select_s3_bucket(collection.name) + executor = MultipartUpload(s3_bucket) + + self.complete_multipart_upload( + executor, serializer.validated_data, serializer.instance, asset + ) + + +class AssetUploadAbort(AssetUploadBase, views_mixins.UpdateInsertModelMixin): + + def post(self, request, *args, **kwargs): + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + + def perform_update(self, serializer): + asset = serializer.instance.asset + + collection = asset.item.collection + + s3_bucket = select_s3_bucket(collection.name) + executor = MultipartUpload(s3_bucket) + self.abort_multipart_upload(executor, serializer.instance, asset) + + +class AssetUploadPartsList(AssetUploadBase): + serializer_class = AssetUploadPartsSerializer + pagination_class = ExtApiPagination + + def get(self, request, *args, **kwargs): + return self.list(request, *args, **kwargs) + + def list(self, request, *args, **kwargs): + asset_upload = self.get_object() + limit, offset = self.get_pagination_config(request) + + collection = asset_upload.asset.item.collection + s3_bucket = select_s3_bucket(collection.name) + + executor = 
MultipartUpload(s3_bucket) + + data, has_next = self.list_multipart_upload_parts( + executor, asset_upload, asset_upload.asset, limit, offset + ) + serializer = self.get_serializer(data) + + return self.get_paginated_response(serializer.data, has_next) + + def get_pagination_config(self, request): + return self.paginator.get_pagination_config(request) + + def get_paginated_response(self, data, has_next): # pylint: disable=arguments-differ + return self.paginator.get_paginated_response(data, has_next) + + +class CollectionAssetUploadBase(SharedAssetUploadBase): + """CollectionAssetUploadBase is the base for all collection asset upload views. + """ + serializer_class = CollectionAssetUploadSerializer + + def get_queryset(self): + return CollectionAssetUpload.objects.filter( + asset__collection__name=self.kwargs['collection_name'], + asset__name=self.kwargs['asset_name'] + ).prefetch_related('asset') + + def get_asset_or_404(self): + return get_object_or_404( + CollectionAsset.objects.all(), + name=self.kwargs['asset_name'], + collection__name=self.kwargs['collection_name'] + ) + + +class CollectionAssetUploadsList( + CollectionAssetUploadBase, mixins.ListModelMixin, views_mixins.CreateModelMixin +): + + class ExternalDisallowedException(Exception): + pass + + def post(self, request, *args, **kwargs): + try: + return self.create(request, *args, **kwargs) + except self.ExternalDisallowedException as ex: + data = { + "code": 400, + "description": "Not allowed to create multipart uploads on external assets" + } + return Response(status=400, exception=True, data=data) + + def get(self, request, *args, **kwargs): + validate_collection_asset(self.kwargs) + return self.list(request, *args, **kwargs) + + def get_success_headers(self, data): + return {'Location': '/'.join([self.request.build_absolute_uri(), data['upload_id']])} + + def perform_create(self, serializer): + data = serializer.validated_data + asset = self.get_asset_or_404() + collection = asset.collection + + if asset.is_external: + raise self.ExternalDisallowedException() + + s3_bucket = select_s3_bucket(collection.name) + executor = MultipartUpload(s3_bucket) + + self.create_multipart_upload(executor, serializer, data, asset) + + def get_queryset(self): + queryset = super().get_queryset() + + status = self.request.query_params.get('status', None) + if status: + queryset = queryset.filter_by_status(status) + + return queryset + + +class CollectionAssetUploadDetail( + CollectionAssetUploadBase, mixins.RetrieveModelMixin, views_mixins.DestroyModelMixin +): + + @etag(get_collection_asset_upload_etag) + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + +class CollectionAssetUploadComplete(CollectionAssetUploadBase, views_mixins.UpdateInsertModelMixin): + + def post(self, request, *args, **kwargs): + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + + def perform_update(self, serializer): + asset = serializer.instance.asset + + collection = asset.collection + + s3_bucket = select_s3_bucket(collection.name) + executor = MultipartUpload(s3_bucket) + + self.complete_multipart_upload( + executor, serializer.validated_data, serializer.instance, asset + ) + + +class CollectionAssetUploadAbort(CollectionAssetUploadBase, views_mixins.UpdateInsertModelMixin): + + def post(self, request, *args, **kwargs): + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + + def perform_update(self, serializer): + asset = serializer.instance.asset + + collection = 
asset.collection + + s3_bucket = select_s3_bucket(collection.name) + executor = MultipartUpload(s3_bucket) + self.abort_multipart_upload(executor, serializer.instance, asset) + + +class CollectionAssetUploadPartsList(CollectionAssetUploadBase): + serializer_class = AssetUploadPartsSerializer + pagination_class = ExtApiPagination + + def get(self, request, *args, **kwargs): + return self.list(request, *args, **kwargs) + + def list(self, request, *args, **kwargs): + asset_upload = self.get_object() + limit, offset = self.get_pagination_config(request) + + collection = asset_upload.asset.collection + s3_bucket = select_s3_bucket(collection.name) + + executor = MultipartUpload(s3_bucket) + + data, has_next = self.list_multipart_upload_parts( + executor, asset_upload, asset_upload.asset, limit, offset + ) + serializer = self.get_serializer(data) + + return self.get_paginated_response(serializer.data, has_next) + + def get_pagination_config(self, request): + return self.paginator.get_pagination_config(request) + + def get_paginated_response(self, data, has_next): # pylint: disable=arguments-differ + return self.paginator.get_paginated_response(data, has_next) diff --git a/app/stac_api/views/views.py b/app/stac_api/views/views.py index bda80893..03e88f46 100644 --- a/app/stac_api/views/views.py +++ b/app/stac_api/views/views.py @@ -1,55 +1,31 @@ import json import logging from datetime import datetime -from operator import itemgetter from django.conf import settings -from django.db import IntegrityError -from django.db import transaction from django.db.models import Min from django.utils.translation import gettext_lazy as _ from rest_framework import generics from rest_framework import mixins from rest_framework import permissions -from rest_framework import serializers from rest_framework.decorators import api_view from rest_framework.decorators import permission_classes -from rest_framework.exceptions import APIException -from rest_framework.generics import get_object_or_404 from rest_framework.permissions import AllowAny from rest_framework.response import Response -from rest_framework_condition import etag - -from stac_api.exceptions import UploadInProgressError -from stac_api.exceptions import UploadNotInProgressError -from stac_api.models import Asset -from stac_api.models import AssetUpload -from stac_api.models import BaseAssetUpload -from stac_api.models import CollectionAsset -from stac_api.models import CollectionAssetUpload + from stac_api.models import Item from stac_api.models import LandingPage -from stac_api.pagination import ExtApiPagination from stac_api.pagination import GetPostCursorPagination -from stac_api.s3_multipart_upload import MultipartUpload from stac_api.serializers.item import ItemSerializer -from stac_api.serializers.serializers import AssetUploadPartsSerializer -from stac_api.serializers.serializers import AssetUploadSerializer -from stac_api.serializers.serializers import CollectionAssetUploadSerializer from stac_api.serializers.serializers import ConformancePageSerializer from stac_api.serializers.serializers import LandingPageSerializer from stac_api.serializers.serializers_utils import get_relation_links from stac_api.utils import call_calculate_extent -from stac_api.utils import get_asset_path -from stac_api.utils import get_collection_asset_path from stac_api.utils import harmonize_post_get_for_search from stac_api.utils import is_api_version_1 -from stac_api.utils import select_s3_bucket from stac_api.utils import utc_aware from 
stac_api.validators_serializer import ValidateSearchRequest -from stac_api.validators_view import validate_asset -from stac_api.validators_view import validate_collection_asset from stac_api.views import views_mixins logger = logging.getLogger(__name__) @@ -61,35 +37,6 @@ def get_etag(queryset): return None -def get_asset_upload_etag(request, *args, **kwargs): - '''Get the ETag for an asset upload object - - The ETag is an UUID4 computed on each object changes - ''' - return get_etag( - AssetUpload.objects.filter( - asset__item__collection__name=kwargs['collection_name'], - asset__item__name=kwargs['item_name'], - asset__name=kwargs['asset_name'], - upload_id=kwargs['upload_id'] - ) - ) - - -def get_collection_asset_upload_etag(request, *args, **kwargs): - '''Get the ETag for a collection asset upload object - - The ETag is an UUID4 computed on each object changes - ''' - return get_etag( - CollectionAssetUpload.objects.filter( - asset__collection__name=kwargs['collection_name'], - asset__name=kwargs['asset_name'], - upload_id=kwargs['upload_id'] - ) - ) - - class LandingPageDetail(generics.RetrieveAPIView): name = 'landing-page' # this name must match the name in urls.py serializer_class = LandingPageSerializer @@ -212,381 +159,3 @@ def post(self, request, *args, **kwargs): def recalculate_extent(request): call_calculate_extent() return Response() - - -class SharedAssetUploadBase(generics.GenericAPIView): - """SharedAssetUploadBase provides a base view for asset uploads and collection asset uploads. - """ - lookup_url_kwarg = "upload_id" - lookup_field = "upload_id" - - def get_queryset(self): - raise NotImplementedError("get_queryset() not implemented") - - def get_in_progress_queryset(self): - return self.get_queryset().filter(status=BaseAssetUpload.Status.IN_PROGRESS) - - def get_asset_or_404(self): - raise NotImplementedError("get_asset_or_404() not implemented") - - def log_extra(self, asset): - if isinstance(asset, CollectionAsset): - return {'collection': asset.collection.name, 'asset': asset.name} - return { - 'collection': asset.item.collection.name, 'item': asset.item.name, 'asset': asset.name - } - - def get_path(self, asset): - if isinstance(asset, CollectionAsset): - return get_collection_asset_path(asset.collection, asset.name) - return get_asset_path(asset.item, asset.name) - - def _save_asset_upload(self, executor, serializer, key, asset, upload_id, urls): - try: - with transaction.atomic(): - serializer.save(asset=asset, upload_id=upload_id, urls=urls) - except IntegrityError as error: - logger.error( - 'Failed to create asset upload multipart: %s', error, extra=self.log_extra(asset) - ) - if bool(self.get_in_progress_queryset()): - raise UploadInProgressError( - data={"upload_id": self.get_in_progress_queryset()[0].upload_id} - ) from None - raise - - def create_multipart_upload(self, executor, serializer, validated_data, asset): - key = self.get_path(asset) - - upload_id = executor.create_multipart_upload( - key, - asset, - validated_data['checksum_multihash'], - validated_data['update_interval'], - validated_data['content_encoding'] - ) - urls = [] - sorted_md5_parts = sorted(validated_data['md5_parts'], key=itemgetter('part_number')) - - try: - for part in sorted_md5_parts: - urls.append( - executor.create_presigned_url( - key, asset, part['part_number'], upload_id, part['md5'] - ) - ) - - self._save_asset_upload(executor, serializer, key, asset, upload_id, urls) - except APIException as err: - executor.abort_multipart_upload(key, asset, upload_id) - raise - - def 
complete_multipart_upload(self, executor, validated_data, asset_upload, asset): - key = self.get_path(asset) - parts = validated_data.get('parts', None) - if parts is None: - raise serializers.ValidationError({ - 'parts': _("Missing required field") - }, code='missing') - if len(parts) > asset_upload.number_parts: - raise serializers.ValidationError({'parts': [_("Too many parts")]}, code='invalid') - if len(parts) < asset_upload.number_parts: - raise serializers.ValidationError({'parts': [_("Too few parts")]}, code='invalid') - if asset_upload.status != BaseAssetUpload.Status.IN_PROGRESS: - raise UploadNotInProgressError() - executor.complete_multipart_upload(key, asset, parts, asset_upload.upload_id) - asset_upload.update_asset_from_upload() - asset_upload.status = BaseAssetUpload.Status.COMPLETED - asset_upload.ended = utc_aware(datetime.utcnow()) - asset_upload.urls = [] - asset_upload.save() - - def abort_multipart_upload(self, executor, asset_upload, asset): - key = self.get_path(asset) - executor.abort_multipart_upload(key, asset, asset_upload.upload_id) - asset_upload.status = BaseAssetUpload.Status.ABORTED - asset_upload.ended = utc_aware(datetime.utcnow()) - asset_upload.urls = [] - asset_upload.save() - - def list_multipart_upload_parts(self, executor, asset_upload, asset, limit, offset): - key = self.get_path(asset) - return executor.list_upload_parts(key, asset, asset_upload.upload_id, limit, offset) - - -class AssetUploadBase(SharedAssetUploadBase): - """AssetUploadBase is the base for all asset (not collection asset) upload views. - """ - serializer_class = AssetUploadSerializer - - def get_queryset(self): - return AssetUpload.objects.filter( - asset__item__collection__name=self.kwargs['collection_name'], - asset__item__name=self.kwargs['item_name'], - asset__name=self.kwargs['asset_name'] - ).prefetch_related('asset') - - def get_asset_or_404(self): - return get_object_or_404( - Asset.objects.all(), - name=self.kwargs['asset_name'], - item__name=self.kwargs['item_name'], - item__collection__name=self.kwargs['collection_name'] - ) - - -class AssetUploadsList(AssetUploadBase, mixins.ListModelMixin, views_mixins.CreateModelMixin): - - class ExternalDisallowedException(Exception): - pass - - def post(self, request, *args, **kwargs): - try: - return self.create(request, *args, **kwargs) - except self.ExternalDisallowedException as ex: - data = { - "code": 400, - "description": "Not allowed to create multipart uploads on external assets" - } - return Response(status=400, exception=True, data=data) - - def get(self, request, *args, **kwargs): - validate_asset(self.kwargs) - return self.list(request, *args, **kwargs) - - def get_success_headers(self, data): - return {'Location': '/'.join([self.request.build_absolute_uri(), data['upload_id']])} - - def perform_create(self, serializer): - data = serializer.validated_data - asset = self.get_asset_or_404() - collection = asset.item.collection - - if asset.is_external: - raise self.ExternalDisallowedException() - - s3_bucket = select_s3_bucket(collection.name) - executor = MultipartUpload(s3_bucket) - - self.create_multipart_upload(executor, serializer, data, asset) - - def get_queryset(self): - queryset = super().get_queryset() - - status = self.request.query_params.get('status', None) - if status: - queryset = queryset.filter_by_status(status) - - return queryset - - -class AssetUploadDetail(AssetUploadBase, mixins.RetrieveModelMixin, views_mixins.DestroyModelMixin): - - @etag(get_asset_upload_etag) - def get(self, request, *args, 
**kwargs): - return self.retrieve(request, *args, **kwargs) - - -class AssetUploadComplete(AssetUploadBase, views_mixins.UpdateInsertModelMixin): - - def post(self, request, *args, **kwargs): - kwargs['partial'] = True - return self.update(request, *args, **kwargs) - - def perform_update(self, serializer): - asset = serializer.instance.asset - - collection = asset.item.collection - - s3_bucket = select_s3_bucket(collection.name) - executor = MultipartUpload(s3_bucket) - - self.complete_multipart_upload( - executor, serializer.validated_data, serializer.instance, asset - ) - - -class AssetUploadAbort(AssetUploadBase, views_mixins.UpdateInsertModelMixin): - - def post(self, request, *args, **kwargs): - kwargs['partial'] = True - return self.update(request, *args, **kwargs) - - def perform_update(self, serializer): - asset = serializer.instance.asset - - collection = asset.item.collection - - s3_bucket = select_s3_bucket(collection.name) - executor = MultipartUpload(s3_bucket) - self.abort_multipart_upload(executor, serializer.instance, asset) - - -class AssetUploadPartsList(AssetUploadBase): - serializer_class = AssetUploadPartsSerializer - pagination_class = ExtApiPagination - - def get(self, request, *args, **kwargs): - return self.list(request, *args, **kwargs) - - def list(self, request, *args, **kwargs): - asset_upload = self.get_object() - limit, offset = self.get_pagination_config(request) - - collection = asset_upload.asset.item.collection - s3_bucket = select_s3_bucket(collection.name) - - executor = MultipartUpload(s3_bucket) - - data, has_next = self.list_multipart_upload_parts( - executor, asset_upload, asset_upload.asset, limit, offset - ) - serializer = self.get_serializer(data) - - return self.get_paginated_response(serializer.data, has_next) - - def get_pagination_config(self, request): - return self.paginator.get_pagination_config(request) - - def get_paginated_response(self, data, has_next): # pylint: disable=arguments-differ - return self.paginator.get_paginated_response(data, has_next) - - -class CollectionAssetUploadBase(SharedAssetUploadBase): - """CollectionAssetUploadBase is the base for all collection asset upload views. 
- """ - serializer_class = CollectionAssetUploadSerializer - - def get_queryset(self): - return CollectionAssetUpload.objects.filter( - asset__collection__name=self.kwargs['collection_name'], - asset__name=self.kwargs['asset_name'] - ).prefetch_related('asset') - - def get_asset_or_404(self): - return get_object_or_404( - CollectionAsset.objects.all(), - name=self.kwargs['asset_name'], - collection__name=self.kwargs['collection_name'] - ) - - -class CollectionAssetUploadsList( - CollectionAssetUploadBase, mixins.ListModelMixin, views_mixins.CreateModelMixin -): - - class ExternalDisallowedException(Exception): - pass - - def post(self, request, *args, **kwargs): - try: - return self.create(request, *args, **kwargs) - except self.ExternalDisallowedException as ex: - data = { - "code": 400, - "description": "Not allowed to create multipart uploads on external assets" - } - return Response(status=400, exception=True, data=data) - - def get(self, request, *args, **kwargs): - validate_collection_asset(self.kwargs) - return self.list(request, *args, **kwargs) - - def get_success_headers(self, data): - return {'Location': '/'.join([self.request.build_absolute_uri(), data['upload_id']])} - - def perform_create(self, serializer): - data = serializer.validated_data - asset = self.get_asset_or_404() - collection = asset.collection - - if asset.is_external: - raise self.ExternalDisallowedException() - - s3_bucket = select_s3_bucket(collection.name) - executor = MultipartUpload(s3_bucket) - - self.create_multipart_upload(executor, serializer, data, asset) - - def get_queryset(self): - queryset = super().get_queryset() - - status = self.request.query_params.get('status', None) - if status: - queryset = queryset.filter_by_status(status) - - return queryset - - -class CollectionAssetUploadDetail( - CollectionAssetUploadBase, mixins.RetrieveModelMixin, views_mixins.DestroyModelMixin -): - - @etag(get_collection_asset_upload_etag) - def get(self, request, *args, **kwargs): - return self.retrieve(request, *args, **kwargs) - - -class CollectionAssetUploadComplete(CollectionAssetUploadBase, views_mixins.UpdateInsertModelMixin): - - def post(self, request, *args, **kwargs): - kwargs['partial'] = True - return self.update(request, *args, **kwargs) - - def perform_update(self, serializer): - asset = serializer.instance.asset - - collection = asset.collection - - s3_bucket = select_s3_bucket(collection.name) - executor = MultipartUpload(s3_bucket) - - self.complete_multipart_upload( - executor, serializer.validated_data, serializer.instance, asset - ) - - -class CollectionAssetUploadAbort(CollectionAssetUploadBase, views_mixins.UpdateInsertModelMixin): - - def post(self, request, *args, **kwargs): - kwargs['partial'] = True - return self.update(request, *args, **kwargs) - - def perform_update(self, serializer): - asset = serializer.instance.asset - - collection = asset.collection - - s3_bucket = select_s3_bucket(collection.name) - executor = MultipartUpload(s3_bucket) - self.abort_multipart_upload(executor, serializer.instance, asset) - - -class CollectionAssetUploadPartsList(CollectionAssetUploadBase): - serializer_class = AssetUploadPartsSerializer - pagination_class = ExtApiPagination - - def get(self, request, *args, **kwargs): - return self.list(request, *args, **kwargs) - - def list(self, request, *args, **kwargs): - asset_upload = self.get_object() - limit, offset = self.get_pagination_config(request) - - collection = asset_upload.asset.collection - s3_bucket = select_s3_bucket(collection.name) - - executor 
= MultipartUpload(s3_bucket) - - data, has_next = self.list_multipart_upload_parts( - executor, asset_upload, asset_upload.asset, limit, offset - ) - serializer = self.get_serializer(data) - - return self.get_paginated_response(serializer.data, has_next) - - def get_pagination_config(self, request): - return self.paginator.get_pagination_config(request) - - def get_paginated_response(self, data, has_next): # pylint: disable=arguments-differ - return self.paginator.get_paginated_response(data, has_next) From f436606f9d12ad09221671f4872d7fd39b19e38c Mon Sep 17 00:00:00 2001 From: Benjamin Sugden Date: Wed, 11 Sep 2024 13:52:23 +0200 Subject: [PATCH 13/16] Rename view files Files should not have the same name as their folder. --- app/config/urls.py | 10 ++++----- app/stac_api/urls.py | 8 +++---- app/stac_api/views/collection.py | 19 +++++++--------- app/stac_api/views/{views.py => general.py} | 4 ++-- app/stac_api/views/item.py | 20 ++++++++--------- .../views/{views_mixins.py => mixins.py} | 0 app/stac_api/views/{views_test.py => test.py} | 0 app/stac_api/views/upload.py | 22 ++++++++++--------- 8 files changed, 41 insertions(+), 42 deletions(-) rename app/stac_api/views/{views.py => general.py} (97%) rename app/stac_api/views/{views_mixins.py => mixins.py} (100%) rename app/stac_api/views/{views_test.py => test.py} (100%) diff --git a/app/config/urls.py b/app/config/urls.py index 0f481c90..f738e8f9 100644 --- a/app/config/urls.py +++ b/app/config/urls.py @@ -38,11 +38,11 @@ def checker(request): if settings.DEBUG: import debug_toolbar - from stac_api.views.views_test import TestAssetUpsertHttp500 - from stac_api.views.views_test import TestCollectionAssetUpsertHttp500 - from stac_api.views.views_test import TestCollectionUpsertHttp500 - from stac_api.views.views_test import TestHttp500 - from stac_api.views.views_test import TestItemUpsertHttp500 + from stac_api.views.test import TestAssetUpsertHttp500 + from stac_api.views.test import TestCollectionAssetUpsertHttp500 + from stac_api.views.test import TestCollectionUpsertHttp500 + from stac_api.views.test import TestHttp500 + from stac_api.views.test import TestItemUpsertHttp500 urlpatterns = [ path('__debug__/', include(debug_toolbar.urls)), diff --git a/app/stac_api/urls.py b/app/stac_api/urls.py index d3c28032..5ccaf6cd 100644 --- a/app/stac_api/urls.py +++ b/app/stac_api/urls.py @@ -8,6 +8,10 @@ from stac_api.views.collection import CollectionAssetsList from stac_api.views.collection import CollectionDetail from stac_api.views.collection import CollectionList +from stac_api.views.general import ConformancePageDetail +from stac_api.views.general import LandingPageDetail +from stac_api.views.general import SearchList +from stac_api.views.general import recalculate_extent from stac_api.views.item import AssetDetail from stac_api.views.item import AssetsList from stac_api.views.item import ItemDetail @@ -22,10 +26,6 @@ from stac_api.views.upload import CollectionAssetUploadDetail from stac_api.views.upload import CollectionAssetUploadPartsList from stac_api.views.upload import CollectionAssetUploadsList -from stac_api.views.views import ConformancePageDetail -from stac_api.views.views import LandingPageDetail -from stac_api.views.views import SearchList -from stac_api.views.views import recalculate_extent # HEALTHCHECK_ENDPOINT = settings.HEALTHCHECK_ENDPOINT diff --git a/app/stac_api/views/collection.py b/app/stac_api/views/collection.py index 2986300e..99cee675 100644 --- a/app/stac_api/views/collection.py +++ 
b/app/stac_api/views/collection.py @@ -16,8 +16,11 @@ from stac_api.utils import get_collection_asset_path from stac_api.validators_view import validate_collection from stac_api.validators_view import validate_renaming -from stac_api.views import views_mixins -from stac_api.views.views import get_etag +from stac_api.views.general import get_etag +from stac_api.views.mixins import DestroyModelMixin +from stac_api.views.mixins import RetrieveModelDynCacheMixin +from stac_api.views.mixins import UpdateInsertModelMixin +from stac_api.views.mixins import patch_cache_settings_by_update_interval logger = logging.getLogger(__name__) @@ -93,10 +96,7 @@ def get(self, request, *args, **kwargs): class CollectionDetail( - generics.GenericAPIView, - mixins.RetrieveModelMixin, - views_mixins.UpdateInsertModelMixin, - views_mixins.DestroyModelMixin + generics.GenericAPIView, mixins.RetrieveModelMixin, UpdateInsertModelMixin, DestroyModelMixin ): # this name must match the name in urls.py and is used by the DestroyModelMixin name = 'collection-detail' @@ -173,15 +173,12 @@ def get(self, request, *args, **kwargs): 'links': get_relation_links(request, self.name, [self.kwargs['collection_name']]) } response = Response(data) - views_mixins.patch_cache_settings_by_update_interval(response, update_interval) + patch_cache_settings_by_update_interval(response, update_interval) return response class CollectionAssetDetail( - generics.GenericAPIView, - views_mixins.UpdateInsertModelMixin, - views_mixins.DestroyModelMixin, - views_mixins.RetrieveModelDynCacheMixin + generics.GenericAPIView, UpdateInsertModelMixin, DestroyModelMixin, RetrieveModelDynCacheMixin ): # this name must match the name in urls.py and is used by the DestroyModelMixin name = 'collection-asset-detail' diff --git a/app/stac_api/views/views.py b/app/stac_api/views/general.py similarity index 97% rename from app/stac_api/views/views.py rename to app/stac_api/views/general.py index 03e88f46..566453d0 100644 --- a/app/stac_api/views/views.py +++ b/app/stac_api/views/general.py @@ -26,7 +26,7 @@ from stac_api.utils import is_api_version_1 from stac_api.utils import utc_aware from stac_api.validators_serializer import ValidateSearchRequest -from stac_api.views import views_mixins +from stac_api.views.mixins import patch_cache_settings_by_update_interval logger = logging.getLogger(__name__) @@ -146,7 +146,7 @@ def list(self, request, *args, **kwargs): def get(self, request, *args, **kwargs): response, min_update_interval = self.list(request, *args, **kwargs) - views_mixins.patch_cache_settings_by_update_interval(response, min_update_interval) + patch_cache_settings_by_update_interval(response, min_update_interval) return response def post(self, request, *args, **kwargs): diff --git a/app/stac_api/views/item.py b/app/stac_api/views/item.py index 6d9469b6..87e05f56 100644 --- a/app/stac_api/views/item.py +++ b/app/stac_api/views/item.py @@ -22,8 +22,8 @@ from stac_api.validators_view import validate_collection from stac_api.validators_view import validate_item from stac_api.validators_view import validate_renaming -from stac_api.views import views_mixins -from stac_api.views.views import get_etag +from stac_api.views import mixins +from stac_api.views.general import get_etag logger = logging.getLogger(__name__) @@ -133,7 +133,7 @@ def list(self, request, *args, **kwargs): if page is not None: response = self.get_paginated_response(data) response = Response(data) - views_mixins.patch_cache_settings_by_update_interval(response, update_interval) + 
mixins.patch_cache_settings_by_update_interval(response, update_interval) return response def get(self, request, *args, **kwargs): @@ -142,9 +142,9 @@ def get(self, request, *args, **kwargs): class ItemDetail( generics.GenericAPIView, - views_mixins.RetrieveModelDynCacheMixin, - views_mixins.UpdateInsertModelMixin, - views_mixins.DestroyModelMixin + mixins.RetrieveModelDynCacheMixin, + mixins.UpdateInsertModelMixin, + mixins.DestroyModelMixin ): # this name must match the name in urls.py and is used by the DestroyModelMixin name = 'item-detail' @@ -246,15 +246,15 @@ def get(self, request, *args, **kwargs): ) } response = Response(data) - views_mixins.patch_cache_settings_by_update_interval(response, update_interval) + mixins.patch_cache_settings_by_update_interval(response, update_interval) return response class AssetDetail( generics.GenericAPIView, - views_mixins.UpdateInsertModelMixin, - views_mixins.DestroyModelMixin, - views_mixins.RetrieveModelDynCacheMixin + mixins.UpdateInsertModelMixin, + mixins.DestroyModelMixin, + mixins.RetrieveModelDynCacheMixin ): # this name must match the name in urls.py and is used by the DestroyModelMixin name = 'asset-detail' diff --git a/app/stac_api/views/views_mixins.py b/app/stac_api/views/mixins.py similarity index 100% rename from app/stac_api/views/views_mixins.py rename to app/stac_api/views/mixins.py diff --git a/app/stac_api/views/views_test.py b/app/stac_api/views/test.py similarity index 100% rename from app/stac_api/views/views_test.py rename to app/stac_api/views/test.py diff --git a/app/stac_api/views/upload.py b/app/stac_api/views/upload.py index ec8c9920..20580cf8 100644 --- a/app/stac_api/views/upload.py +++ b/app/stac_api/views/upload.py @@ -32,8 +32,10 @@ from stac_api.utils import utc_aware from stac_api.validators_view import validate_asset from stac_api.validators_view import validate_collection_asset -from stac_api.views import views_mixins -from stac_api.views.views import get_etag +from stac_api.views.general import get_etag +from stac_api.views.mixins import CreateModelMixin +from stac_api.views.mixins import DestroyModelMixin +from stac_api.views.mixins import UpdateInsertModelMixin logger = logging.getLogger(__name__) @@ -188,7 +190,7 @@ def get_asset_or_404(self): ) -class AssetUploadsList(AssetUploadBase, mixins.ListModelMixin, views_mixins.CreateModelMixin): +class AssetUploadsList(AssetUploadBase, mixins.ListModelMixin, CreateModelMixin): class ExternalDisallowedException(Exception): pass @@ -233,14 +235,14 @@ def get_queryset(self): return queryset -class AssetUploadDetail(AssetUploadBase, mixins.RetrieveModelMixin, views_mixins.DestroyModelMixin): +class AssetUploadDetail(AssetUploadBase, mixins.RetrieveModelMixin, DestroyModelMixin): @etag(get_asset_upload_etag) def get(self, request, *args, **kwargs): return self.retrieve(request, *args, **kwargs) -class AssetUploadComplete(AssetUploadBase, views_mixins.UpdateInsertModelMixin): +class AssetUploadComplete(AssetUploadBase, UpdateInsertModelMixin): def post(self, request, *args, **kwargs): kwargs['partial'] = True @@ -259,7 +261,7 @@ def perform_update(self, serializer): ) -class AssetUploadAbort(AssetUploadBase, views_mixins.UpdateInsertModelMixin): +class AssetUploadAbort(AssetUploadBase, UpdateInsertModelMixin): def post(self, request, *args, **kwargs): kwargs['partial'] = True @@ -325,7 +327,7 @@ def get_asset_or_404(self): class CollectionAssetUploadsList( - CollectionAssetUploadBase, mixins.ListModelMixin, views_mixins.CreateModelMixin + CollectionAssetUploadBase, 
mixins.ListModelMixin, CreateModelMixin ): class ExternalDisallowedException(Exception): @@ -372,7 +374,7 @@ def get_queryset(self): class CollectionAssetUploadDetail( - CollectionAssetUploadBase, mixins.RetrieveModelMixin, views_mixins.DestroyModelMixin + CollectionAssetUploadBase, mixins.RetrieveModelMixin, DestroyModelMixin ): @etag(get_collection_asset_upload_etag) @@ -380,7 +382,7 @@ def get(self, request, *args, **kwargs): return self.retrieve(request, *args, **kwargs) -class CollectionAssetUploadComplete(CollectionAssetUploadBase, views_mixins.UpdateInsertModelMixin): +class CollectionAssetUploadComplete(CollectionAssetUploadBase, UpdateInsertModelMixin): def post(self, request, *args, **kwargs): kwargs['partial'] = True @@ -399,7 +401,7 @@ def perform_update(self, serializer): ) -class CollectionAssetUploadAbort(CollectionAssetUploadBase, views_mixins.UpdateInsertModelMixin): +class CollectionAssetUploadAbort(CollectionAssetUploadBase, UpdateInsertModelMixin): def post(self, request, *args, **kwargs): kwargs['partial'] = True From 64c9bf8a87f18828dd8a199c4cb36d774f10769c Mon Sep 17 00:00:00 2001 From: Benjamin Sugden Date: Wed, 11 Sep 2024 14:05:35 +0200 Subject: [PATCH 14/16] Move upload serializers to own file --- .../management/commands/list_asset_uploads.py | 2 +- app/stac_api/serializers/collection.py | 4 +- app/stac_api/serializers/item.py | 4 +- app/stac_api/serializers/serializers.py | 251 ------------------ app/stac_api/serializers/serializers_utils.py | 28 ++ app/stac_api/serializers/upload.py | 231 ++++++++++++++++ app/stac_api/views/upload.py | 6 +- .../tests_09/test_serializer_asset_upload.py | 2 +- .../tests_10/test_serializer_asset_upload.py | 2 +- 9 files changed, 269 insertions(+), 261 deletions(-) create mode 100644 app/stac_api/serializers/upload.py diff --git a/app/stac_api/management/commands/list_asset_uploads.py b/app/stac_api/management/commands/list_asset_uploads.py index ff4d1b83..079ac3ec 100644 --- a/app/stac_api/management/commands/list_asset_uploads.py +++ b/app/stac_api/management/commands/list_asset_uploads.py @@ -6,7 +6,7 @@ from stac_api.models import AssetUpload from stac_api.s3_multipart_upload import MultipartUpload -from stac_api.serializers.serializers import AssetUploadSerializer +from stac_api.serializers.upload import AssetUploadSerializer from stac_api.utils import CommandHandler from stac_api.utils import get_asset_path diff --git a/app/stac_api/serializers/collection.py b/app/stac_api/serializers/collection.py index 1968e7fb..b01cd722 100644 --- a/app/stac_api/serializers/collection.py +++ b/app/stac_api/serializers/collection.py @@ -8,8 +8,8 @@ from stac_api.models import CollectionAsset from stac_api.models import CollectionLink from stac_api.models import Provider -from stac_api.serializers.serializers import AssetsDictSerializer -from stac_api.serializers.serializers import HrefField +from stac_api.serializers.serializers_utils import AssetsDictSerializer +from stac_api.serializers.serializers_utils import HrefField from stac_api.serializers.serializers_utils import NonNullModelSerializer from stac_api.serializers.serializers_utils import UpsertModelSerializerMixin from stac_api.serializers.serializers_utils import get_relation_links diff --git a/app/stac_api/serializers/item.py b/app/stac_api/serializers/item.py index 7c545855..eb067475 100644 --- a/app/stac_api/serializers/item.py +++ b/app/stac_api/serializers/item.py @@ -9,8 +9,8 @@ from stac_api.models import Asset from stac_api.models import Item from stac_api.models import 
ItemLink -from stac_api.serializers.serializers import AssetsDictSerializer -from stac_api.serializers.serializers import HrefField +from stac_api.serializers.serializers_utils import AssetsDictSerializer +from stac_api.serializers.serializers_utils import HrefField from stac_api.serializers.serializers_utils import NonNullModelSerializer from stac_api.serializers.serializers_utils import UpsertModelSerializerMixin from stac_api.serializers.serializers_utils import get_relation_links diff --git a/app/stac_api/serializers/serializers.py b/app/stac_api/serializers/serializers.py index ecb87a3e..4d1aba7e 100644 --- a/app/stac_api/serializers/serializers.py +++ b/app/stac_api/serializers/serializers.py @@ -6,24 +6,14 @@ from django.utils.translation import gettext_lazy as _ from rest_framework import serializers -from rest_framework.utils.serializer_helpers import ReturnDict from rest_framework.validators import UniqueValidator -from stac_api.models import AssetUpload -from stac_api.models import CollectionAssetUpload from stac_api.models import LandingPage from stac_api.models import LandingPageLink -from stac_api.serializers.serializers_utils import DictSerializer -from stac_api.serializers.serializers_utils import NonNullModelSerializer -from stac_api.utils import build_asset_href from stac_api.utils import get_browser_url from stac_api.utils import get_stac_version from stac_api.utils import get_url from stac_api.utils import is_api_version_1 -from stac_api.utils import isoformat -from stac_api.validators import validate_checksum_multihash_sha256 -from stac_api.validators import validate_content_encoding -from stac_api.validators import validate_md5_parts from stac_api.validators import validate_name logger = logging.getLogger(__name__) @@ -138,244 +128,3 @@ def to_representation(self, instance): ]), ] return representation - - -class AssetsDictSerializer(DictSerializer): - '''Assets serializer list to dictionary - - This serializer returns an asset dictionary with the asset name as keys. - ''' - # pylint: disable=abstract-method - key_identifier = 'id' - - -class HrefField(serializers.Field): - '''Special Href field for Assets''' - - # pylint: disable=abstract-method - - def to_representation(self, value): - # build an absolute URL from the file path - request = self.context.get("request") - path = value.name - - if value.instance.is_external: - return path - return build_asset_href(request, path) - - def to_internal_value(self, data): - return data - - -class AssetUploadListSerializer(serializers.ListSerializer): - # pylint: disable=abstract-method - - def to_representation(self, data): - return {'uploads': super().to_representation(data)} - - @property - def data(self): - ret = super(serializers.ListSerializer, self).data - return ReturnDict(ret, serializer=self) - - -class UploadPartSerializer(serializers.Serializer): - '''This serializer is used to serialize the data from/to the S3 API. 
- ''' - # pylint: disable=abstract-method - etag = serializers.CharField(source='ETag', allow_blank=False, required=True) - part_number = serializers.IntegerField( - source='PartNumber', min_value=1, max_value=100, required=True, allow_null=False - ) - modified = serializers.DateTimeField(source='LastModified', required=False, allow_null=True) - size = serializers.IntegerField(source='Size', allow_null=True, required=False) - - -class AssetUploadSerializer(NonNullModelSerializer): - - class Meta: - model = AssetUpload - list_serializer_class = AssetUploadListSerializer - fields = [ - 'upload_id', - 'status', - 'created', - 'checksum_multihash', - 'completed', - 'aborted', - 'number_parts', - 'md5_parts', - 'urls', - 'ended', - 'parts', - 'update_interval', - 'content_encoding' - ] - - checksum_multihash = serializers.CharField( - source='checksum_multihash', - max_length=255, - required=True, - allow_blank=False, - validators=[validate_checksum_multihash_sha256] - ) - md5_parts = serializers.JSONField(required=True) - update_interval = serializers.IntegerField( - required=False, allow_null=False, min_value=-1, max_value=3600, default=-1 - ) - content_encoding = serializers.CharField( - required=False, - allow_null=False, - allow_blank=False, - min_length=1, - max_length=32, - default='', - validators=[validate_content_encoding] - ) - - # write only fields - ended = serializers.DateTimeField(write_only=True, required=False) - parts = serializers.ListField( - child=UploadPartSerializer(), write_only=True, allow_empty=False, required=False - ) - - # Read only fields - upload_id = serializers.CharField(read_only=True) - created = serializers.DateTimeField(read_only=True) - urls = serializers.JSONField(read_only=True) - completed = serializers.SerializerMethodField() - aborted = serializers.SerializerMethodField() - - def validate(self, attrs): - # get partial from kwargs (if partial true and no md5 : ok, if false no md5 : error) - # Check the md5 parts length - if attrs.get('md5_parts') is not None: - validate_md5_parts(attrs['md5_parts'], attrs['number_parts']) - elif not self.partial: - raise serializers.ValidationError( - detail={'md5_parts': _('md5_parts parameter is missing')}, code='missing' - ) - return attrs - - def get_completed(self, obj): - if obj.status == AssetUpload.Status.COMPLETED: - return isoformat(obj.ended) - return None - - def get_aborted(self, obj): - if obj.status == AssetUpload.Status.ABORTED: - return isoformat(obj.ended) - return None - - def get_fields(self): - fields = super().get_fields() - # This is a hack to allow fields with special characters - fields['file:checksum'] = fields.pop('checksum_multihash') - - # Older versions of the api still use different name - request = self.context.get('request') - if not is_api_version_1(request): - fields['checksum:multihash'] = fields.pop('file:checksum') - return fields - - -class AssetUploadPartsSerializer(serializers.Serializer): - '''S3 list_parts response serializer''' - - # pylint: disable=abstract-method - - class Meta: - list_serializer_class = AssetUploadListSerializer - - # Read only fields - parts = serializers.ListField( - source='Parts', child=UploadPartSerializer(), default=list, read_only=True - ) - - -class CollectionAssetUploadSerializer(NonNullModelSerializer): - - class Meta: - model = CollectionAssetUpload - list_serializer_class = AssetUploadListSerializer - fields = [ - 'upload_id', - 'status', - 'created', - 'checksum_multihash', - 'completed', - 'aborted', - 'number_parts', - 'md5_parts', - 'urls', 
- 'ended', - 'parts', - 'update_interval', - 'content_encoding' - ] - - checksum_multihash = serializers.CharField( - source='checksum_multihash', - max_length=255, - required=True, - allow_blank=False, - validators=[validate_checksum_multihash_sha256] - ) - md5_parts = serializers.JSONField(required=True) - update_interval = serializers.IntegerField( - required=False, allow_null=False, min_value=-1, max_value=3600, default=-1 - ) - content_encoding = serializers.CharField( - required=False, - allow_null=False, - allow_blank=False, - min_length=1, - max_length=32, - default='', - validators=[validate_content_encoding] - ) - - # write only fields - ended = serializers.DateTimeField(write_only=True, required=False) - parts = serializers.ListField( - child=UploadPartSerializer(), write_only=True, allow_empty=False, required=False - ) - - # Read only fields - upload_id = serializers.CharField(read_only=True) - created = serializers.DateTimeField(read_only=True) - urls = serializers.JSONField(read_only=True) - completed = serializers.SerializerMethodField() - aborted = serializers.SerializerMethodField() - - def validate(self, attrs): - # get partial from kwargs (if partial true and no md5 : ok, if false no md5 : error) - # Check the md5 parts length - if attrs.get('md5_parts') is not None: - validate_md5_parts(attrs['md5_parts'], attrs['number_parts']) - elif not self.partial: - raise serializers.ValidationError( - detail={'md5_parts': _('md5_parts parameter is missing')}, code='missing' - ) - return attrs - - def get_completed(self, obj): - if obj.status == CollectionAssetUpload.Status.COMPLETED: - return isoformat(obj.ended) - return None - - def get_aborted(self, obj): - if obj.status == CollectionAssetUpload.Status.ABORTED: - return isoformat(obj.ended) - return None - - def get_fields(self): - fields = super().get_fields() - # This is a hack to allow fields with special characters - fields['file:checksum'] = fields.pop('checksum_multihash') - - # Older versions of the api still use different name - request = self.context.get('request') - if not is_api_version_1(request): - fields['checksum:multihash'] = fields.pop('file:checksum') - return fields diff --git a/app/stac_api/serializers/serializers_utils.py b/app/stac_api/serializers/serializers_utils.py index 68b0a2b6..70ae5220 100644 --- a/app/stac_api/serializers/serializers_utils.py +++ b/app/stac_api/serializers/serializers_utils.py @@ -4,6 +4,7 @@ from rest_framework import serializers from rest_framework.utils.serializer_helpers import ReturnDict +from stac_api.utils import build_asset_href from stac_api.utils import get_browser_url from stac_api.utils import get_url @@ -311,3 +312,30 @@ def to_representation(self, data): def data(self): ret = super(serializers.ListSerializer, self).data return ReturnDict(ret, serializer=self) + + +class AssetsDictSerializer(DictSerializer): + '''Assets serializer list to dictionary + + This serializer returns an asset dictionary with the asset name as keys. 
+ ''' + # pylint: disable=abstract-method + key_identifier = 'id' + + +class HrefField(serializers.Field): + '''Special Href field for Assets''' + + # pylint: disable=abstract-method + + def to_representation(self, value): + # build an absolute URL from the file path + request = self.context.get("request") + path = value.name + + if value.instance.is_external: + return path + return build_asset_href(request, path) + + def to_internal_value(self, data): + return data diff --git a/app/stac_api/serializers/upload.py b/app/stac_api/serializers/upload.py new file mode 100644 index 00000000..286d4a90 --- /dev/null +++ b/app/stac_api/serializers/upload.py @@ -0,0 +1,231 @@ +import logging + +from django.utils.translation import gettext_lazy as _ + +from rest_framework import serializers +from rest_framework.utils.serializer_helpers import ReturnDict + +from stac_api.models import AssetUpload +from stac_api.models import CollectionAssetUpload +from stac_api.serializers.serializers_utils import NonNullModelSerializer +from stac_api.utils import is_api_version_1 +from stac_api.utils import isoformat +from stac_api.validators import validate_checksum_multihash_sha256 +from stac_api.validators import validate_content_encoding +from stac_api.validators import validate_md5_parts + +logger = logging.getLogger(__name__) + + +class AssetUploadListSerializer(serializers.ListSerializer): + # pylint: disable=abstract-method + + def to_representation(self, data): + return {'uploads': super().to_representation(data)} + + @property + def data(self): + ret = super(serializers.ListSerializer, self).data + return ReturnDict(ret, serializer=self) + + +class UploadPartSerializer(serializers.Serializer): + '''This serializer is used to serialize the data from/to the S3 API. + ''' + # pylint: disable=abstract-method + etag = serializers.CharField(source='ETag', allow_blank=False, required=True) + part_number = serializers.IntegerField( + source='PartNumber', min_value=1, max_value=100, required=True, allow_null=False + ) + modified = serializers.DateTimeField(source='LastModified', required=False, allow_null=True) + size = serializers.IntegerField(source='Size', allow_null=True, required=False) + + +class AssetUploadSerializer(NonNullModelSerializer): + + class Meta: + model = AssetUpload + list_serializer_class = AssetUploadListSerializer + fields = [ + 'upload_id', + 'status', + 'created', + 'checksum_multihash', + 'completed', + 'aborted', + 'number_parts', + 'md5_parts', + 'urls', + 'ended', + 'parts', + 'update_interval', + 'content_encoding' + ] + + checksum_multihash = serializers.CharField( + source='checksum_multihash', + max_length=255, + required=True, + allow_blank=False, + validators=[validate_checksum_multihash_sha256] + ) + md5_parts = serializers.JSONField(required=True) + update_interval = serializers.IntegerField( + required=False, allow_null=False, min_value=-1, max_value=3600, default=-1 + ) + content_encoding = serializers.CharField( + required=False, + allow_null=False, + allow_blank=False, + min_length=1, + max_length=32, + default='', + validators=[validate_content_encoding] + ) + + # write only fields + ended = serializers.DateTimeField(write_only=True, required=False) + parts = serializers.ListField( + child=UploadPartSerializer(), write_only=True, allow_empty=False, required=False + ) + + # Read only fields + upload_id = serializers.CharField(read_only=True) + created = serializers.DateTimeField(read_only=True) + urls = serializers.JSONField(read_only=True) + completed = 
serializers.SerializerMethodField() + aborted = serializers.SerializerMethodField() + + def validate(self, attrs): + # get partial from kwargs (if partial true and no md5 : ok, if false no md5 : error) + # Check the md5 parts length + if attrs.get('md5_parts') is not None: + validate_md5_parts(attrs['md5_parts'], attrs['number_parts']) + elif not self.partial: + raise serializers.ValidationError( + detail={'md5_parts': _('md5_parts parameter is missing')}, code='missing' + ) + return attrs + + def get_completed(self, obj): + if obj.status == AssetUpload.Status.COMPLETED: + return isoformat(obj.ended) + return None + + def get_aborted(self, obj): + if obj.status == AssetUpload.Status.ABORTED: + return isoformat(obj.ended) + return None + + def get_fields(self): + fields = super().get_fields() + # This is a hack to allow fields with special characters + fields['file:checksum'] = fields.pop('checksum_multihash') + + # Older versions of the api still use different name + request = self.context.get('request') + if not is_api_version_1(request): + fields['checksum:multihash'] = fields.pop('file:checksum') + return fields + + +class AssetUploadPartsSerializer(serializers.Serializer): + '''S3 list_parts response serializer''' + + # pylint: disable=abstract-method + + class Meta: + list_serializer_class = AssetUploadListSerializer + + # Read only fields + parts = serializers.ListField( + source='Parts', child=UploadPartSerializer(), default=list, read_only=True + ) + + +class CollectionAssetUploadSerializer(NonNullModelSerializer): + + class Meta: + model = CollectionAssetUpload + list_serializer_class = AssetUploadListSerializer + fields = [ + 'upload_id', + 'status', + 'created', + 'checksum_multihash', + 'completed', + 'aborted', + 'number_parts', + 'md5_parts', + 'urls', + 'ended', + 'parts', + 'update_interval', + 'content_encoding' + ] + + checksum_multihash = serializers.CharField( + source='checksum_multihash', + max_length=255, + required=True, + allow_blank=False, + validators=[validate_checksum_multihash_sha256] + ) + md5_parts = serializers.JSONField(required=True) + update_interval = serializers.IntegerField( + required=False, allow_null=False, min_value=-1, max_value=3600, default=-1 + ) + content_encoding = serializers.CharField( + required=False, + allow_null=False, + allow_blank=False, + min_length=1, + max_length=32, + default='', + validators=[validate_content_encoding] + ) + + # write only fields + ended = serializers.DateTimeField(write_only=True, required=False) + parts = serializers.ListField( + child=UploadPartSerializer(), write_only=True, allow_empty=False, required=False + ) + + # Read only fields + upload_id = serializers.CharField(read_only=True) + created = serializers.DateTimeField(read_only=True) + urls = serializers.JSONField(read_only=True) + completed = serializers.SerializerMethodField() + aborted = serializers.SerializerMethodField() + + def validate(self, attrs): + # get partial from kwargs (if partial true and no md5 : ok, if false no md5 : error) + # Check the md5 parts length + if attrs.get('md5_parts') is not None: + validate_md5_parts(attrs['md5_parts'], attrs['number_parts']) + elif not self.partial: + raise serializers.ValidationError( + detail={'md5_parts': _('md5_parts parameter is missing')}, code='missing' + ) + return attrs + + def get_completed(self, obj): + if obj.status == CollectionAssetUpload.Status.COMPLETED: + return isoformat(obj.ended) + return None + + def get_aborted(self, obj): + if obj.status == CollectionAssetUpload.Status.ABORTED: + 
return isoformat(obj.ended) + return None + + def get_fields(self): + fields = super().get_fields() + # This is a hack to allow fields with special characters + fields['file:checksum'] = fields.pop('checksum_multihash') + + # Older versions of the api still use different name + request = self.context.get('request') + if not is_api_version_1(request): + fields['checksum:multihash'] = fields.pop('file:checksum') + return fields diff --git a/app/stac_api/views/upload.py b/app/stac_api/views/upload.py index 20580cf8..7927b6fc 100644 --- a/app/stac_api/views/upload.py +++ b/app/stac_api/views/upload.py @@ -23,9 +23,9 @@ from stac_api.models import CollectionAssetUpload from stac_api.pagination import ExtApiPagination from stac_api.s3_multipart_upload import MultipartUpload -from stac_api.serializers.serializers import AssetUploadPartsSerializer -from stac_api.serializers.serializers import AssetUploadSerializer -from stac_api.serializers.serializers import CollectionAssetUploadSerializer +from stac_api.serializers.upload import AssetUploadPartsSerializer +from stac_api.serializers.upload import AssetUploadSerializer +from stac_api.serializers.upload import CollectionAssetUploadSerializer from stac_api.utils import get_asset_path from stac_api.utils import get_collection_asset_path from stac_api.utils import select_s3_bucket diff --git a/app/tests/tests_09/test_serializer_asset_upload.py b/app/tests/tests_09/test_serializer_asset_upload.py index b1287734..e5a54e43 100644 --- a/app/tests/tests_09/test_serializer_asset_upload.py +++ b/app/tests/tests_09/test_serializer_asset_upload.py @@ -7,7 +7,7 @@ from rest_framework.exceptions import ValidationError from stac_api.models import AssetUpload -from stac_api.serializers.serializers import AssetUploadSerializer +from stac_api.serializers.upload import AssetUploadSerializer from stac_api.utils import get_sha256_multihash from tests.tests_09.base_test import STAC_BASE_V diff --git a/app/tests/tests_10/test_serializer_asset_upload.py b/app/tests/tests_10/test_serializer_asset_upload.py index 9c7c30ab..c5d359a2 100644 --- a/app/tests/tests_10/test_serializer_asset_upload.py +++ b/app/tests/tests_10/test_serializer_asset_upload.py @@ -7,7 +7,7 @@ from rest_framework.exceptions import ValidationError from stac_api.models import AssetUpload -from stac_api.serializers.serializers import AssetUploadSerializer +from stac_api.serializers.upload import AssetUploadSerializer from stac_api.utils import get_sha256_multihash from tests.tests_10.base_test import StacBaseTestCase From 6236d11a3d186d2df7cec4070077853f9ffb1ec2 Mon Sep 17 00:00:00 2001 From: Benjamin Sugden Date: Wed, 11 Sep 2024 14:49:53 +0200 Subject: [PATCH 15/16] Rename serializer files File names should not be the same as their folder --- app/stac_api/serializers/collection.py | 12 ++++++------ .../serializers/{serializers.py => general.py} | 0 app/stac_api/serializers/item.py | 12 ++++++------ app/stac_api/serializers/upload.py | 2 +- .../serializers/{serializers_utils.py => utils.py} | 0 app/stac_api/views/collection.py | 2 +- app/stac_api/views/general.py | 6 +++--- app/stac_api/views/item.py | 2 +- app/stac_api/views/mixins.py | 2 +- scripts/fill_local_db.py | 8 ++++---- 10 files changed, 23 insertions(+), 23 deletions(-) rename app/stac_api/serializers/{serializers.py => general.py} (100%) rename app/stac_api/serializers/{serializers_utils.py => utils.py} (100%) diff --git a/app/stac_api/serializers/collection.py b/app/stac_api/serializers/collection.py index b01cd722..655d6848 100644 --- 
a/app/stac_api/serializers/collection.py +++ b/app/stac_api/serializers/collection.py @@ -8,12 +8,12 @@ from stac_api.models import CollectionAsset from stac_api.models import CollectionLink from stac_api.models import Provider -from stac_api.serializers.serializers_utils import AssetsDictSerializer -from stac_api.serializers.serializers_utils import HrefField -from stac_api.serializers.serializers_utils import NonNullModelSerializer -from stac_api.serializers.serializers_utils import UpsertModelSerializerMixin -from stac_api.serializers.serializers_utils import get_relation_links -from stac_api.serializers.serializers_utils import update_or_create_links +from stac_api.serializers.utils import AssetsDictSerializer +from stac_api.serializers.utils import HrefField +from stac_api.serializers.utils import NonNullModelSerializer +from stac_api.serializers.utils import UpsertModelSerializerMixin +from stac_api.serializers.utils import get_relation_links +from stac_api.serializers.utils import update_or_create_links from stac_api.utils import get_stac_version from stac_api.utils import is_api_version_1 from stac_api.utils import isoformat diff --git a/app/stac_api/serializers/serializers.py b/app/stac_api/serializers/general.py similarity index 100% rename from app/stac_api/serializers/serializers.py rename to app/stac_api/serializers/general.py diff --git a/app/stac_api/serializers/item.py b/app/stac_api/serializers/item.py index eb067475..b0cb9fe2 100644 --- a/app/stac_api/serializers/item.py +++ b/app/stac_api/serializers/item.py @@ -9,12 +9,12 @@ from stac_api.models import Asset from stac_api.models import Item from stac_api.models import ItemLink -from stac_api.serializers.serializers_utils import AssetsDictSerializer -from stac_api.serializers.serializers_utils import HrefField -from stac_api.serializers.serializers_utils import NonNullModelSerializer -from stac_api.serializers.serializers_utils import UpsertModelSerializerMixin -from stac_api.serializers.serializers_utils import get_relation_links -from stac_api.serializers.serializers_utils import update_or_create_links +from stac_api.serializers.utils import AssetsDictSerializer +from stac_api.serializers.utils import HrefField +from stac_api.serializers.utils import NonNullModelSerializer +from stac_api.serializers.utils import UpsertModelSerializerMixin +from stac_api.serializers.utils import get_relation_links +from stac_api.serializers.utils import update_or_create_links from stac_api.utils import get_stac_version from stac_api.utils import is_api_version_1 from stac_api.validators import normalize_and_validate_media_type diff --git a/app/stac_api/serializers/upload.py b/app/stac_api/serializers/upload.py index 286d4a90..9a42887c 100644 --- a/app/stac_api/serializers/upload.py +++ b/app/stac_api/serializers/upload.py @@ -7,7 +7,7 @@ from stac_api.models import AssetUpload from stac_api.models import CollectionAssetUpload -from stac_api.serializers.serializers_utils import NonNullModelSerializer +from stac_api.serializers.utils import NonNullModelSerializer from stac_api.utils import is_api_version_1 from stac_api.utils import isoformat from stac_api.validators import validate_checksum_multihash_sha256 diff --git a/app/stac_api/serializers/serializers_utils.py b/app/stac_api/serializers/utils.py similarity index 100% rename from app/stac_api/serializers/serializers_utils.py rename to app/stac_api/serializers/utils.py diff --git a/app/stac_api/views/collection.py b/app/stac_api/views/collection.py index 99cee675..1a14d4be 100644 --- 
a/app/stac_api/views/collection.py +++ b/app/stac_api/views/collection.py @@ -12,7 +12,7 @@ from stac_api.models import CollectionAsset from stac_api.serializers.collection import CollectionAssetSerializer from stac_api.serializers.collection import CollectionSerializer -from stac_api.serializers.serializers_utils import get_relation_links +from stac_api.serializers.utils import get_relation_links from stac_api.utils import get_collection_asset_path from stac_api.validators_view import validate_collection from stac_api.validators_view import validate_renaming diff --git a/app/stac_api/views/general.py b/app/stac_api/views/general.py index 566453d0..3b4d3a71 100644 --- a/app/stac_api/views/general.py +++ b/app/stac_api/views/general.py @@ -17,10 +17,10 @@ from stac_api.models import Item from stac_api.models import LandingPage from stac_api.pagination import GetPostCursorPagination +from stac_api.serializers.general import ConformancePageSerializer +from stac_api.serializers.general import LandingPageSerializer from stac_api.serializers.item import ItemSerializer -from stac_api.serializers.serializers import ConformancePageSerializer -from stac_api.serializers.serializers import LandingPageSerializer -from stac_api.serializers.serializers_utils import get_relation_links +from stac_api.serializers.utils import get_relation_links from stac_api.utils import call_calculate_extent from stac_api.utils import harmonize_post_get_for_search from stac_api.utils import is_api_version_1 diff --git a/app/stac_api/views/item.py b/app/stac_api/views/item.py index 87e05f56..4efa73cf 100644 --- a/app/stac_api/views/item.py +++ b/app/stac_api/views/item.py @@ -16,7 +16,7 @@ from stac_api.models import Item from stac_api.serializers.item import AssetSerializer from stac_api.serializers.item import ItemSerializer -from stac_api.serializers.serializers_utils import get_relation_links +from stac_api.serializers.utils import get_relation_links from stac_api.utils import get_asset_path from stac_api.utils import utc_aware from stac_api.validators_view import validate_collection diff --git a/app/stac_api/views/mixins.py b/app/stac_api/views/mixins.py index 8d75b186..8e8f1569 100644 --- a/app/stac_api/views/mixins.py +++ b/app/stac_api/views/mixins.py @@ -11,7 +11,7 @@ from rest_framework import status from rest_framework.response import Response -from stac_api.serializers.serializers_utils import get_parent_link +from stac_api.serializers.utils import get_parent_link from stac_api.utils import get_dynamic_max_age_value from stac_api.utils import get_link diff --git a/scripts/fill_local_db.py b/scripts/fill_local_db.py index d219b096..6db19f8e 100644 --- a/scripts/fill_local_db.py +++ b/scripts/fill_local_db.py @@ -20,10 +20,10 @@ from rest_framework.renderers import JSONRenderer from rest_framework.parsers import JSONParser from stac_api.models import * -from stac_api.serializers.serializers import CollectionSerializer -from stac_api.serializers.serializers import CollectionSerializer -from stac_api.serializers.serializers import LinkSerializer -from stac_api.serializers.serializers import ProviderSerializer +from stac_api.serializers.general import CollectionSerializer +from stac_api.serializers.general import CollectionSerializer +from stac_api.serializers.general import LinkSerializer +from stac_api.serializers.general import ProviderSerializer # create link instances for testing From 67a0b8e3142d06292ed8ae419ea7be44909f41f8 Mon Sep 17 00:00:00 2001 From: Brice Schaffner Date: Thu, 12 Sep 2024 15:25:54 +0200 
Subject: [PATCH 16/16] PB-932: Avoid UTF-8 truncation in logging

The request body is UTF-8 encoded, so a single character takes anywhere from
one to four bytes. We therefore have to decode before truncating: slicing the
raw bytes can cut a multi-byte character in half, which makes the subsequent
decode fail.
---
 app/middleware/logging.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/app/middleware/logging.py b/app/middleware/logging.py
index 5fa71038..6b2d7c3e 100644
--- a/app/middleware/logging.py
+++ b/app/middleware/logging.py
@@ -28,7 +28,7 @@ def __call__(self, request):
 ] and request.content_type == "application/json" and not request.path.startswith(
 '/api/stac/admin'
 ):
-        extra["request.payload"] = request.body[:200].decode()
+        extra["request.payload"] = request.body.decode()[:200]
 
 logger.debug(
 "Request %s %s?%s",
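
To make the failure mode concrete, here is a minimal standalone sketch
(hypothetical payload, not part of the patch). "€" encodes to three bytes in
UTF-8, so slicing the raw bytes at 200 cuts a character in half, while slicing
the decoded string cannot:

# Standalone illustration of the bug fixed above (hypothetical payload).
body = ("€" * 100).encode("utf-8")  # 300 bytes, 3 bytes per character

try:
    body[:200].decode()  # old behaviour: truncate bytes, then decode
except UnicodeDecodeError as err:
    # bytes 198-199 are the first two bytes of the 67th "€"
    print("truncated mid-character:", err)

# new behaviour: decode first, then truncate to 200 *characters*
print(body.decode()[:200])

Note that the fixed line caps the payload at 200 characters rather than 200
bytes; for a debug-log preview that is the more useful unit anyway.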
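
Stepping back to the collection asset upload endpoints added earlier in the
series: they implement the usual S3-style multipart protocol (announce the
upload with per-part MD5s, PUT each part to a presigned URL, then complete or
abort). Below is a hedged client-side sketch. The host and endpoint paths are
assumptions inferred from the view names (the URL conf is not part of this
excerpt), as is the shape of the "urls" entries; the field names
(number_parts, md5_parts, upload_id, urls, parts, file:checksum) come from the
upload serializers above. Authentication headers are omitted for brevity.

import base64
import hashlib

import requests

# Hypothetical host, collection and asset.
BASE = "https://data.example.com/api/stac/v1"
UPLOADS = f"{BASE}/collections/my-collection/assets/data.tif/uploads"

part = b"some asset content"  # a real upload would use multi-MiB parts
md5 = base64.b64encode(hashlib.md5(part).digest()).decode()  # Content-MD5 form

# 1. Announce the upload; the response carries upload_id and presigned URLs.
upload = requests.post(
    UPLOADS,
    json={
        "number_parts": 1,
        "md5_parts": [{"part_number": 1, "md5": md5}],
        # sha2-256 multihash of the file (0x12 = sha2-256, 0x20 = 32 bytes)
        "file:checksum": "1220" + hashlib.sha256(part).hexdigest(),
    },
    timeout=30,
).json()

# 2. PUT each part to its presigned URL and keep the ETag S3 returns.
#    (Indexing upload["urls"][0]["url"] assumes one dict per part.)
response = requests.put(
    upload["urls"][0]["url"], data=part, headers={"Content-MD5": md5}, timeout=30
)
etag = response.headers["ETag"]

# 3. Complete (or abort) the upload with the collected parts.
requests.post(
    f"{UPLOADS}/{upload['upload_id']}/complete",
    json={"parts": [{"part_number": 1, "etag": etag}]},
    timeout=30,
)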