diff --git a/app/config/urls.py b/app/config/urls.py index a0eb4c8a..f738e8f9 100644 --- a/app/config/urls.py +++ b/app/config/urls.py @@ -38,10 +38,11 @@ def checker(request): if settings.DEBUG: import debug_toolbar - from stac_api.views_test import TestAssetUpsertHttp500 - from stac_api.views_test import TestCollectionUpsertHttp500 - from stac_api.views_test import TestHttp500 - from stac_api.views_test import TestItemUpsertHttp500 + from stac_api.views.test import TestAssetUpsertHttp500 + from stac_api.views.test import TestCollectionAssetUpsertHttp500 + from stac_api.views.test import TestCollectionUpsertHttp500 + from stac_api.views.test import TestHttp500 + from stac_api.views.test import TestItemUpsertHttp500 urlpatterns = [ path('__debug__/', include(debug_toolbar.urls)), @@ -61,6 +62,11 @@ def checker(request): TestAssetUpsertHttp500.as_view(), name='test-asset-detail-http-500' ), + path( + 'tests/test_collection_asset_upsert_http_500//', + TestCollectionAssetUpsertHttp500.as_view(), + name='test-collection-asset-detail-http-500' + ), # Add v0.9 namespace to test routes. path( 'tests/v0.9/test_asset_upsert_http_500///', diff --git a/app/middleware/logging.py b/app/middleware/logging.py index 5fa71038..6b2d7c3e 100644 --- a/app/middleware/logging.py +++ b/app/middleware/logging.py @@ -28,7 +28,7 @@ def __call__(self, request): ] and request.content_type == "application/json" and not request.path.startswith( '/api/stac/admin' ): - extra["request.payload"] = request.body[:200].decode() + extra["request.payload"] = request.body.decode()[:200] logger.debug( "Request %s %s?%s", diff --git a/app/stac_api/management/commands/list_asset_uploads.py b/app/stac_api/management/commands/list_asset_uploads.py index 19c90ece..079ac3ec 100644 --- a/app/stac_api/management/commands/list_asset_uploads.py +++ b/app/stac_api/management/commands/list_asset_uploads.py @@ -6,7 +6,7 @@ from stac_api.models import AssetUpload from stac_api.s3_multipart_upload import MultipartUpload -from stac_api.serializers import AssetUploadSerializer +from stac_api.serializers.upload import AssetUploadSerializer from stac_api.utils import CommandHandler from stac_api.utils import get_asset_path diff --git a/app/stac_api/management/commands/profile_item_serializer.py b/app/stac_api/management/commands/profile_item_serializer.py index 84ce4cf8..47f2ba79 100644 --- a/app/stac_api/management/commands/profile_item_serializer.py +++ b/app/stac_api/management/commands/profile_item_serializer.py @@ -20,7 +20,7 @@ class Handler(CommandHandler): def profiling(self): # pylint: disable=import-outside-toplevel,possibly-unused-variable - from stac_api.serializers import ItemSerializer + from stac_api.serializers.item import ItemSerializer collection_id = self.options["collection"] qs = Item.objects.filter(collection__name=collection_id ).prefetch_related('assets', 'links')[:self.options['limit']] diff --git a/app/stac_api/management/commands/profile_serializer_vs_no_drf.py b/app/stac_api/management/commands/profile_serializer_vs_no_drf.py index 64da7525..d4b1ddca 100644 --- a/app/stac_api/management/commands/profile_serializer_vs_no_drf.py +++ b/app/stac_api/management/commands/profile_serializer_vs_no_drf.py @@ -21,7 +21,7 @@ def profiling(self): # pylint: disable=import-outside-toplevel,possibly-unused-variable self.print('Starting profiling') - from stac_api.serializers import ItemSerializer + from stac_api.serializers.item import ItemSerializer def serialize(qs): return { diff --git 
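Note on the logging middleware change above: the request payload is now decoded before it is truncated. Slicing the raw bytes first (`request.body[:200].decode()`) can raise `UnicodeDecodeError` whenever the 200-byte cut lands inside a multi-byte UTF-8 sequence; decoding the whole body and then slicing the resulting string cannot. A minimal standalone illustration (not part of the diff):

# Each '€' encodes to 3 bytes in UTF-8, so a 200-byte cut always splits a character.
body = ('€' * 100).encode('utf-8')

try:
    body[:200].decode()          # old behaviour: truncate the bytes, then decode
except UnicodeDecodeError as error:
    print('byte-level truncation failed:', error)

print(body.decode()[:200])       # new behaviour: decode, then truncate the string

The trade-off is that the full body is now decoded before truncation, which is slightly more work for very large payloads.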
a/app/stac_api/migrations/0050_collectionassetupload_and_more.py b/app/stac_api/migrations/0050_collectionassetupload_and_more.py new file mode 100644 index 00000000..1238a62a --- /dev/null +++ b/app/stac_api/migrations/0050_collectionassetupload_and_more.py @@ -0,0 +1,147 @@ +# Generated by Django 5.0.8 on 2024-09-10 12:45 + +import pgtrigger.compiler +import pgtrigger.migrations + +import django.core.serializers.json +import django.core.validators +import django.db.models.deletion +from django.db import migrations +from django.db import models + +import stac_api.models + + +class Migration(migrations.Migration): + + dependencies = [ + ('stac_api', '0049_item_properties_expires'), + ] + + operations = [ + migrations.AlterField( + model_name='assetupload', + name='update_interval', + field=models.IntegerField( + default=-1, + help_text= + 'Interval in seconds in which the asset data is updated. -1 means that the data is not on a regular basis updated. This field can only be set via the API.', + validators=[django.core.validators.MinValueValidator(-1)] + ), + ), + migrations.CreateModel( + name='CollectionAssetUpload', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('upload_id', models.CharField(max_length=255)), + ( + 'status', + models.CharField( + choices=[(None, ''), ('in-progress', 'In Progress'), + ('completed', 'Completed'), ('aborted', 'Aborted')], + default='in-progress', + max_length=32 + ) + ), + ( + 'number_parts', + models.IntegerField( + validators=[ + django.core.validators.MinValueValidator(1), + django.core.validators.MaxValueValidator(100) + ] + ) + ), + ( + 'md5_parts', + models.JSONField( + editable=False, encoder=django.core.serializers.json.DjangoJSONEncoder + ) + ), + ( + 'urls', + models.JSONField( + blank=True, + default=list, + encoder=django.core.serializers.json.DjangoJSONEncoder + ) + ), + ('created', models.DateTimeField(auto_now_add=True)), + ('ended', models.DateTimeField(blank=True, default=None, null=True)), + ('checksum_multihash', models.CharField(max_length=255)), + ('etag', models.CharField(default=stac_api.models.compute_etag, max_length=56)), + ( + 'update_interval', + models.IntegerField( + default=-1, + help_text= + 'Interval in seconds in which the asset data is updated. -1 means that the data is not on a regular basis updated. 
This field can only be set via the API.', + validators=[django.core.validators.MinValueValidator(-1)] + ) + ), + ( + 'content_encoding', + models.CharField( + blank=True, + choices=[(None, ''), ('gzip', 'Gzip'), ('br', 'Br')], + default='', + max_length=32 + ) + ), + ( + 'asset', + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name='+', + to='stac_api.collectionasset' + ) + ), + ], + ), + migrations.AddConstraint( + model_name='collectionassetupload', + constraint=models.UniqueConstraint( + fields=('asset', 'upload_id'), + name='unique_asset_upload_collection_asset_upload_id' + ), + ), + migrations.AddConstraint( + model_name='collectionassetupload', + constraint=models.UniqueConstraint( + condition=models.Q(('status', 'in-progress')), + fields=('asset', 'status'), + name='unique_asset_upload_in_progress' + ), + ), + pgtrigger.migrations.AddTrigger( + model_name='collectionassetupload', + trigger=pgtrigger.compiler.Trigger( + name='add_asset_upload_trigger', + sql=pgtrigger.compiler.UpsertTriggerSql( + func= + '\n -- update AssetUpload auto variable\n NEW.etag = public.gen_random_uuid();\n\n RETURN NEW;\n ', + hash='5f51ec3c72c4d9fbe6b81d2fd881dd5228dc80bf', + operation='INSERT', + pgid='pgtrigger_add_asset_upload_trigger_8330c', + table='stac_api_collectionassetupload', + when='BEFORE' + ) + ), + ), + pgtrigger.migrations.AddTrigger( + model_name='collectionassetupload', + trigger=pgtrigger.compiler.Trigger( + name='update_asset_upload_trigger', + sql=pgtrigger.compiler.UpsertTriggerSql( + condition='WHEN (OLD.* IS DISTINCT FROM NEW.*)', + func= + '\n -- update AssetUpload auto variable\n NEW.etag = public.gen_random_uuid();\n\n RETURN NEW;\n ', + hash='0a7f1aa8f8c0bb2c413a7ce626f75c8da5bf4b6d', + operation='UPDATE', + pgid='pgtrigger_update_asset_upload_trigger_8d012', + table='stac_api_collectionassetupload', + when='BEFORE' + ) + ), + ), + ] diff --git a/app/stac_api/models.py b/app/stac_api/models.py index e192e22c..c4ef6de3 100644 --- a/app/stac_api/models.py +++ b/app/stac_api/models.py @@ -710,19 +710,10 @@ def get_asset_path(self): return get_collection_asset_path(self.collection, self.name) -class AssetUpload(models.Model): +class BaseAssetUpload(models.Model): class Meta: - constraints = [ - models.UniqueConstraint(fields=['asset', 'upload_id'], name='unique_together'), - # Make sure that there is only one asset upload in progress per asset - models.UniqueConstraint( - fields=['asset', 'status'], - condition=Q(status='in-progress'), - name='unique_in_progress' - ) - ] - triggers = generates_asset_upload_triggers() + abstract = True class Status(models.TextChoices): # pylint: disable=invalid-name @@ -741,7 +732,6 @@ class ContentEncoding(models.TextChoices): # using BigIntegerField as primary_key to deal with the expected large number of assets. id = models.BigAutoField(primary_key=True) - asset = models.ForeignKey(Asset, related_name='+', on_delete=models.CASCADE) upload_id = models.CharField(max_length=255, blank=False, null=False) status = models.CharField( choices=Status.choices, max_length=32, default=Status.IN_PROGRESS, blank=False, null=False @@ -765,8 +755,8 @@ class ContentEncoding(models.TextChoices): null=False, blank=False, validators=[MinValueValidator(-1)], - help_text="Interval in seconds in which the asset data is updated." - "-1 means that the data is not on a regular basis updated." + help_text="Interval in seconds in which the asset data is updated. " + "-1 means that the data is not on a regular basis updated. 
" "This field can only be set via the API." ) @@ -777,6 +767,23 @@ class ContentEncoding(models.TextChoices): # Custom Manager that preselects the collection objects = AssetUploadManager() + +class AssetUpload(BaseAssetUpload): + + class Meta: + constraints = [ + models.UniqueConstraint(fields=['asset', 'upload_id'], name='unique_together'), + # Make sure that there is only one asset upload in progress per asset + models.UniqueConstraint( + fields=['asset', 'status'], + condition=Q(status='in-progress'), + name='unique_in_progress' + ) + ] + triggers = generates_asset_upload_triggers() + + asset = models.ForeignKey(Asset, related_name='+', on_delete=models.CASCADE) + def update_asset_from_upload(self): '''Updating the asset's file:checksum and update_interval from the upload @@ -804,6 +811,51 @@ def update_asset_from_upload(self): self.asset.save() +class CollectionAssetUpload(BaseAssetUpload): + + class Meta: + constraints = [ + models.UniqueConstraint( + fields=['asset', 'upload_id'], + name='unique_asset_upload_collection_asset_upload_id' + ), + # Make sure that there is only one upload in progress per collection asset + models.UniqueConstraint( + fields=['asset', 'status'], + condition=Q(status='in-progress'), + name='unique_asset_upload_in_progress' + ) + ] + triggers = generates_asset_upload_triggers() + + asset = models.ForeignKey(CollectionAsset, related_name='+', on_delete=models.CASCADE) + + def update_asset_from_upload(self): + '''Updating the asset's file:checksum and update_interval from the upload + + When the upload is completed, the new file:checksum and update interval from the upload + is set to its asset parent. + ''' + logger.debug( + 'Updating collection asset %s file:checksum from %s to %s and update_interval ' + 'from %d to %d due to upload complete', + self.asset.name, + self.asset.checksum_multihash, + self.checksum_multihash, + self.asset.update_interval, + self.update_interval, + extra={ + 'upload_id': self.upload_id, + 'asset': self.asset.name, + 'collection': self.asset.collection.name + } + ) + + self.asset.checksum_multihash = self.checksum_multihash + self.asset.update_interval = self.update_interval + self.asset.save() + + class CountBase(models.Model): '''CountBase tables are used to help calculate the summary on a collection. This is only performant if the distinct number of values is small, e.g. 
we currently only have diff --git a/app/stac_api/s3_multipart_upload.py b/app/stac_api/s3_multipart_upload.py index 5b1239f1..0c5a10c8 100644 --- a/app/stac_api/s3_multipart_upload.py +++ b/app/stac_api/s3_multipart_upload.py @@ -12,6 +12,8 @@ from rest_framework import serializers from stac_api.exceptions import UploadNotInProgressError +from stac_api.models import Asset +from stac_api.models import CollectionAsset from stac_api.utils import AVAILABLE_S3_BUCKETS from stac_api.utils import get_s3_cache_control_value from stac_api.utils import get_s3_client @@ -65,6 +67,24 @@ def list_multipart_uploads(self, key=None, limit=100, start=None): response.get('NextUploadIdMarker', None), ) + def log_extra(self, asset: Asset | CollectionAsset, upload_id=None, parts=None): + if isinstance(asset, Asset): + log_extra = { + 'collection': asset.item.collection.name, + 'item': asset.item.name, + 'asset': asset.name, + } + else: + log_extra = { + 'collection': asset.collection.name, + 'asset': asset.name, + } + if upload_id is not None: + log_extra['upload_id'] = upload_id + if parts is not None: + log_extra['parts'] = parts + return log_extra + def create_multipart_upload( self, key, asset, checksum_multihash, update_interval, content_encoding ): @@ -99,11 +119,7 @@ def create_multipart_upload( CacheControl=get_s3_cache_control_value(update_interval), ContentType=asset.media_type, **extra_params, - log_extra={ - 'collection': asset.item.collection.name, - 'item': asset.item.name, - 'asset': asset.name - } + log_extra=self.log_extra(asset) ) logger.info( 'S3 Multipart upload successfully created: upload_id=%s', @@ -148,12 +164,7 @@ def create_presigned_url(self, key, asset, part, upload_id, part_md5): Params=params, ExpiresIn=settings.AWS_PRESIGNED_URL_EXPIRES, HttpMethod='PUT', - log_extra={ - 'collection': asset.item.collection.name, - 'item': asset.item.name, - 'asset': asset.name, - 'upload_id': upload_id - } + log_extra=self.log_extra(asset, upload_id=upload_id) ) logger.info( @@ -191,13 +202,7 @@ def complete_multipart_upload(self, key, asset, parts, upload_id): Key=key, MultipartUpload={'Parts': parts}, UploadId=upload_id, - log_extra={ - 'parts': parts, - 'upload_id': upload_id, - 'collection': asset.item.collection.name, - 'item': asset.item.name, - 'asset': asset.name - } + log_extra=self.log_extra(asset, upload_id=upload_id, parts=parts) ) except ClientError as error: raise serializers.ValidationError(str(error)) from None @@ -271,12 +276,7 @@ def list_upload_parts(self, key, asset, upload_id, limit, offset): UploadId=upload_id, MaxParts=limit, PartNumberMarker=offset, - log_extra={ - 'collection': asset.item.collection.name, - 'item': asset.item.name, - 'asset': asset.name, - 'upload_id': upload_id - } + log_extra=self.log_extra(asset, upload_id=upload_id) ) return response, response.get('IsTruncated', False) diff --git a/app/stac_api/serializers.py b/app/stac_api/serializers.py deleted file mode 100644 index 343f059e..00000000 --- a/app/stac_api/serializers.py +++ /dev/null @@ -1,928 +0,0 @@ -import logging -from collections import OrderedDict -from urllib.parse import urlparse - -from django.conf import settings -from django.contrib.gis.geos import GEOSGeometry -from django.core.exceptions import ValidationError as CoreValidationError -from django.utils.translation import gettext_lazy as _ - -from rest_framework import serializers -from rest_framework.utils.serializer_helpers import ReturnDict -from rest_framework.validators import UniqueValidator -from rest_framework_gis import 
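The `MultipartUpload` changes above (the new `log_extra()` helper and its call sites) only touch the logging context; the helper methods themselves keep wrapping the regular S3 multipart-upload API for both item assets and collection assets. For orientation, this is roughly the underlying boto3 flow with the service's checksum, cache-control, MD5 and logging handling stripped out; bucket and key are placeholders, not values used by the service:

import boto3

s3 = boto3.client('s3')
bucket, key = 'example-bucket', 'collection/item/asset.tif'

# 1. open the multipart upload
upload_id = s3.create_multipart_upload(Bucket=bucket, Key=key)['UploadId']

# 2. hand out one presigned PUT URL per part for the client to upload to
url = s3.generate_presigned_url(
    'upload_part',
    Params={'Bucket': bucket, 'Key': key, 'UploadId': upload_id, 'PartNumber': 1},
    ExpiresIn=3600,
    HttpMethod='PUT',
)

# 3. once every part is uploaded, close the upload with the collected ETags
s3.complete_multipart_upload(
    Bucket=bucket,
    Key=key,
    UploadId=upload_id,
    MultipartUpload={'Parts': [{'ETag': '"part-1-etag"', 'PartNumber': 1}]},
)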
serializers as gis_serializers - -from stac_api.models import Asset -from stac_api.models import AssetUpload -from stac_api.models import Collection -from stac_api.models import CollectionLink -from stac_api.models import Item -from stac_api.models import ItemLink -from stac_api.models import LandingPage -from stac_api.models import LandingPageLink -from stac_api.models import Provider -from stac_api.serializers_utils import DictSerializer -from stac_api.serializers_utils import NonNullModelSerializer -from stac_api.serializers_utils import UpsertModelSerializerMixin -from stac_api.serializers_utils import get_relation_links -from stac_api.serializers_utils import update_or_create_links -from stac_api.utils import build_asset_href -from stac_api.utils import get_browser_url -from stac_api.utils import get_stac_version -from stac_api.utils import get_url -from stac_api.utils import is_api_version_1 -from stac_api.utils import isoformat -from stac_api.validators import normalize_and_validate_media_type -from stac_api.validators import validate_asset_name -from stac_api.validators import validate_asset_name_with_media_type -from stac_api.validators import validate_checksum_multihash_sha256 -from stac_api.validators import validate_content_encoding -from stac_api.validators import validate_geoadmin_variant -from stac_api.validators import validate_href_url -from stac_api.validators import validate_item_properties_datetimes -from stac_api.validators import validate_md5_parts -from stac_api.validators import validate_name -from stac_api.validators_serializer import validate_json_payload -from stac_api.validators_serializer import validate_uniqueness_and_create - -logger = logging.getLogger(__name__) - - -class LandingPageLinkSerializer(serializers.ModelSerializer): - - class Meta: - model = LandingPageLink - fields = ['href', 'rel', 'link_type', 'title'] - - -class ConformancePageSerializer(serializers.ModelSerializer): - - class Meta: - model = LandingPage - fields = ['conformsTo'] - - -class LandingPageSerializer(serializers.ModelSerializer): - - class Meta: - model = LandingPage - fields = ['id', 'title', 'description', 'links', 'stac_version', 'conformsTo'] - - # NOTE: when explicitely declaring fields, we need to add the validation as for the field - # in model ! - id = serializers.CharField( - max_length=255, - source="name", - validators=[validate_name, UniqueValidator(queryset=LandingPage.objects.all())] - ) - # Read only fields - links = LandingPageLinkSerializer(many=True, read_only=True) - stac_version = serializers.SerializerMethodField() - - def get_stac_version(self, obj): - return get_stac_version(self.context.get('request')) - - def to_representation(self, instance): - representation = super().to_representation(instance) - request = self.context.get("request") - - # Add hardcoded value Catalog to response to conform to stac spec v1. 
- representation['type'] = "Catalog" - - # Remove property on older versions - if not is_api_version_1(request): - del representation['type'] - - version = request.resolver_match.namespace - spec_base = f'{urlparse(settings.STATIC_SPEC_URL).path.strip(' / ')}/{version}' - # Add auto links - # We use OrderedDict, although it is not necessary, because the default serializer/model for - # links already uses OrderedDict, this way we keep consistency between auto link and user - # link - representation['links'][:0] = [ - OrderedDict([ - ('rel', 'root'), - ('href', get_url(request, 'landing-page')), - ("type", "application/json"), - ]), - OrderedDict([ - ('rel', 'self'), - ('href', get_url(request, 'landing-page')), - ("type", "application/json"), - ("title", "This document"), - ]), - OrderedDict([ - ("rel", "service-doc"), - ("href", request.build_absolute_uri(f"/{spec_base}/api.html")), - ("type", "text/html"), - ("title", "The API documentation"), - ]), - OrderedDict([ - ("rel", "service-desc"), - ("href", request.build_absolute_uri(f"/{spec_base}/openapi.yaml")), - ("type", "application/vnd.oai.openapi+yaml;version=3.0"), - ("title", "The OPENAPI description of the service"), - ]), - OrderedDict([ - ("rel", "conformance"), - ("href", get_url(request, 'conformance')), - ("type", "application/json"), - ("title", "OGC API conformance classes implemented by this server"), - ]), - OrderedDict([ - ('rel', 'data'), - ('href', get_url(request, 'collections-list')), - ("type", "application/json"), - ("title", "Information about the feature collections"), - ]), - OrderedDict([ - ("href", get_url(request, 'search-list')), - ("rel", "search"), - ("method", "GET"), - ("type", "application/json"), - ("title", "Search across feature collections"), - ]), - OrderedDict([ - ("href", get_url(request, 'search-list')), - ("rel", "search"), - ("method", "POST"), - ("type", "application/json"), - ("title", "Search across feature collections"), - ]), - OrderedDict([ - ("href", get_browser_url(request, 'browser-catalog')), - ("rel", "alternate"), - ("type", "text/html"), - ("title", "STAC Browser"), - ]), - ] - return representation - - -class ProviderSerializer(NonNullModelSerializer): - - class Meta: - model = Provider - fields = ['name', 'roles', 'url', 'description'] - - -class CollectionLinkSerializer(NonNullModelSerializer): - - class Meta: - model = CollectionLink - fields = ['href', 'rel', 'title', 'type'] - - # NOTE: when explicitely declaring fields, we need to add the validation as for the field - # in model ! - type = serializers.CharField( - required=False, allow_blank=True, max_length=150, source="link_type" - ) - - -class ItemLinkSerializer(NonNullModelSerializer): - - class Meta: - model = ItemLink - fields = ['href', 'rel', 'title', 'type'] - - # NOTE: when explicitely declaring fields, we need to add the validation as for the field - # in model ! - type = serializers.CharField( - required=False, allow_blank=True, max_length=255, source="link_type" - ) - - -class ItemsPropertiesSerializer(serializers.Serializer): - # pylint: disable=abstract-method - # ItemsPropertiesSerializer is a nested serializer and don't directly create/write instances - # therefore we don't need to implement the super method create() and update() - - # NOTE: when explicitely declaring fields, we need to add the validation as for the field - # in model ! 
- datetime = serializers.DateTimeField(source='properties_datetime', required=False, default=None) - start_datetime = serializers.DateTimeField( - source='properties_start_datetime', required=False, default=None - ) - end_datetime = serializers.DateTimeField( - source='properties_end_datetime', required=False, default=None - ) - title = serializers.CharField( - source='properties_title', - required=False, - allow_blank=False, - allow_null=True, - max_length=255, - default=None - ) - created = serializers.DateTimeField(read_only=True) - updated = serializers.DateTimeField(read_only=True) - expires = serializers.DateTimeField(source='properties_expires', required=False, default=None) - - -class BboxSerializer(gis_serializers.GeoFeatureModelSerializer): - - class Meta: - model = Item - geo_field = "geometry" - auto_bbox = True - fields = ['geometry'] - - def to_representation(self, instance): - python_native = super().to_representation(instance) - return python_native['bbox'] - - -class AssetsDictSerializer(DictSerializer): - '''Assets serializer list to dictionary - - This serializer returns an asset dictionary with the asset name as keys. - ''' - # pylint: disable=abstract-method - key_identifier = 'id' - - -class HrefField(serializers.Field): - '''Special Href field for Assets''' - - # pylint: disable=abstract-method - - def to_representation(self, value): - # build an absolute URL from the file path - request = self.context.get("request") - path = value.name - - if value.instance.is_external: - return path - return build_asset_href(request, path) - - def to_internal_value(self, data): - return data - - -class AssetBaseSerializer(NonNullModelSerializer, UpsertModelSerializerMixin): - '''Asset serializer base class - ''' - - class Meta: - model = Asset - fields = [ - 'id', - 'title', - 'type', - 'href', - 'description', - 'eo_gsd', - 'roles', - 'geoadmin_lang', - 'geoadmin_variant', - 'proj_epsg', - 'checksum_multihash', - 'created', - 'updated', - ] - validators = [] # Remove a default "unique together" constraint. - # (see: - # https://www.django-rest-framework.org/api-guide/validators/#limitations-of-validators) - - # NOTE: when explicitely declaring fields, we need to add the validation as for the field - # in model ! - id = serializers.CharField(source='name', max_length=255, validators=[validate_asset_name]) - title = serializers.CharField( - required=False, max_length=255, allow_null=True, allow_blank=False - ) - description = serializers.CharField(required=False, allow_blank=False, allow_null=True) - # Can't be a ChoiceField, as the validate method normalizes the MIME string only after it - # is read. Consistency is nevertheless guaranteed by the validate() and validate_type() methods. 
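`AssetsDictSerializer` above only sets `key_identifier = 'id'`; the actual re-keying lives in the project's `DictSerializer` base class, which is not shown in this diff. Judging from its usage and docstring, the behaviour is presumably along these lines (a sketch, not the real implementation): the nested assets are serialized as usual and then returned as a dictionary keyed by asset name instead of a list.

from rest_framework import serializers


class DictSerializer(serializers.ListSerializer):
    # Sketch: serialize the children normally, then key the result by `key_identifier`,
    # so "assets" renders as {"asset-name": {...}, ...} rather than a JSON array.
    key_identifier = 'id'

    def to_representation(self, data):
        objects = super().to_representation(data)
        return {obj[self.key_identifier]: obj for obj in objects}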
- type = serializers.CharField( - source='media_type', required=True, allow_null=False, allow_blank=False - ) - # Here we need to explicitely define these fields with the source, because they are renamed - # in the get_fields() method - eo_gsd = serializers.FloatField(source='eo_gsd', required=False, allow_null=True) - geoadmin_lang = serializers.ChoiceField( - source='geoadmin_lang', - choices=Asset.Language.values, - required=False, - allow_null=True, - allow_blank=False - ) - geoadmin_variant = serializers.CharField( - source='geoadmin_variant', - max_length=25, - allow_blank=False, - allow_null=True, - required=False, - validators=[validate_geoadmin_variant] - ) - proj_epsg = serializers.IntegerField(source='proj_epsg', allow_null=True, required=False) - # read only fields - checksum_multihash = serializers.CharField(source='checksum_multihash', read_only=True) - href = HrefField(source='file', required=False) - created = serializers.DateTimeField(read_only=True) - updated = serializers.DateTimeField(read_only=True) - - # helper variable to provide the collection for upsert validation - # see views.AssetDetail.perform_upsert - collection = None - - def create(self, validated_data): - asset = validate_uniqueness_and_create(Asset, validated_data) - return asset - - def update_or_create(self, look_up, validated_data): - """ - Update or create the asset object selected by kwargs and return the instance. - When no asset object matching the kwargs selection, a new asset is created. - Args: - validated_data: dict - Copy of the validated_data to use for update - kwargs: dict - Object selection arguments (NOTE: the selection arguments must match a unique - object in DB otherwise an IntegrityError will be raised) - Returns: tuple - Asset instance and True if created otherwise false - """ - asset, created = Asset.objects.update_or_create(**look_up, defaults=validated_data) - return asset, created - - def validate_type(self, value): - ''' Validates the the field "type" - ''' - return normalize_and_validate_media_type(value) - - def validate(self, attrs): - name = attrs['name'] if not self.partial else attrs.get('name', self.instance.name) - media_type = attrs['media_type'] if not self.partial else attrs.get( - 'media_type', self.instance.media_type - ) - validate_asset_name_with_media_type(name, media_type) - - validate_json_payload(self) - - return attrs - - def get_fields(self): - fields = super().get_fields() - # This is a hack to allow fields with special characters - fields['gsd'] = fields.pop('eo_gsd') - fields['proj:epsg'] = fields.pop('proj_epsg') - fields['geoadmin:variant'] = fields.pop('geoadmin_variant') - fields['geoadmin:lang'] = fields.pop('geoadmin_lang') - fields['file:checksum'] = fields.pop('checksum_multihash') - - # Older versions of the api still use different name - request = self.context.get('request') - if not is_api_version_1(request): - fields['checksum:multihash'] = fields.pop('file:checksum') - fields['eo:gsd'] = fields.pop('gsd') - fields.pop('roles', None) - - return fields - - -class AssetSerializer(AssetBaseSerializer): - '''Asset serializer for the asset views - - This serializer adds the links list attribute. 
- ''' - - def to_representation(self, instance): - collection = instance.item.collection.name - item = instance.item.name - name = instance.name - request = self.context.get("request") - representation = super().to_representation(instance) - # Add auto links - # We use OrderedDict, although it is not necessary, because the default serializer/model for - # links already uses OrderedDict, this way we keep consistency between auto link and user - # link - representation['links'] = get_relation_links( - request, 'asset-detail', [collection, item, name] - ) - return representation - - def _validate_href_field(self, attrs): - """Only allow the href field if the collection allows for external assets - - Raise an exception, this replicates the previous behaviour when href - was always read_only - """ - # the href field is translated to the file field here - if 'file' in attrs: - if self.collection: - collection = self.collection - else: - raise LookupError("No collection defined.") - - if not collection.allow_external_assets: - logger.info( - 'Attempted external asset upload with no permission', - extra={ - 'collection': self.collection, 'attrs': attrs - } - ) - errors = {'href': _("Found read-only property in payload")} - raise serializers.ValidationError(code="payload", detail=errors) - - try: - validate_href_url(attrs['file'], collection) - except CoreValidationError as e: - errors = {'href': e.message} - raise serializers.ValidationError(code='payload', detail=errors) - - def validate(self, attrs): - self._validate_href_field(attrs) - return super().validate(attrs) - - -class AssetsForItemSerializer(AssetBaseSerializer): - '''Assets serializer for nesting them inside the item - - Assets should be nested inside their item but using a dictionary instead of a list and without - links. - ''' - - class Meta: - model = Asset - list_serializer_class = AssetsDictSerializer - fields = [ - 'id', - 'title', - 'type', - 'href', - 'description', - 'roles', - 'eo_gsd', - 'geoadmin_lang', - 'geoadmin_variant', - 'proj_epsg', - 'checksum_multihash', - 'created', - 'updated' - ] - - -class CollectionSerializer(NonNullModelSerializer, UpsertModelSerializerMixin): - - class Meta: - model = Collection - fields = [ - 'published', - 'stac_version', - 'stac_extensions', - 'id', - 'title', - 'description', - 'summaries', - 'extent', - 'providers', - 'license', - 'created', - 'updated', - 'links', - 'crs', - 'itemType', - 'assets' - ] - # crs not in sample data, but in specs.. - validators = [] # Remove a default "unique together" constraint. - # (see: - # https://www.django-rest-framework.org/api-guide/validators/#limitations-of-validators) - - published = serializers.BooleanField(write_only=True, default=True) - # NOTE: when explicitely declaring fields, we need to add the validation as for the field - # in model ! 
- id = serializers.CharField( - required=True, max_length=255, source="name", validators=[validate_name] - ) - title = serializers.CharField(required=False, allow_blank=False, default=None, max_length=255) - # Also links are required in the spec, the main links (self, root, items) are automatically - # generated hence here it is set to required=False which allows to add optional links that - # are not generated - links = CollectionLinkSerializer(required=False, many=True) - providers = ProviderSerializer(required=False, many=True) - - # read only fields - crs = serializers.SerializerMethodField(read_only=True) - created = serializers.DateTimeField(read_only=True) - updated = serializers.DateTimeField(read_only=True) - extent = serializers.SerializerMethodField(read_only=True) - summaries = serializers.SerializerMethodField(read_only=True) - stac_extensions = serializers.SerializerMethodField(read_only=True) - stac_version = serializers.SerializerMethodField(read_only=True) - itemType = serializers.ReadOnlyField(default="Feature") # pylint: disable=invalid-name - assets = AssetsForItemSerializer(many=True, read_only=True) - - def get_crs(self, obj): - return ["http://www.opengis.net/def/crs/OGC/1.3/CRS84"] - - def get_stac_extensions(self, obj): - return [] - - def get_stac_version(self, obj): - return get_stac_version(self.context.get('request')) - - def get_summaries(self, obj): - # Older versions of the api still use different name - request = self.context.get('request') - if not is_api_version_1(request): - return { - 'proj:epsg': obj.summaries_proj_epsg or [], - 'eo:gsd': obj.summaries_eo_gsd or [], - 'geoadmin:variant': obj.summaries_geoadmin_variant or [], - 'geoadmin:lang': obj.summaries_geoadmin_lang or [] - } - return { - 'proj:epsg': obj.summaries_proj_epsg or [], - 'gsd': obj.summaries_eo_gsd or [], - 'geoadmin:variant': obj.summaries_geoadmin_variant or [], - 'geoadmin:lang': obj.summaries_geoadmin_lang or [] - } - - def get_extent(self, obj): - start = obj.extent_start_datetime - end = obj.extent_end_datetime - if start is not None: - start = isoformat(start) - if end is not None: - end = isoformat(end) - - bbox = [0, 0, 0, 0] - if obj.extent_geometry is not None: - bbox = list(GEOSGeometry(obj.extent_geometry).extent) - - return { - "spatial": { - "bbox": [bbox] - }, - "temporal": { - "interval": [[start, end]] - }, - } - - def _update_or_create_providers(self, collection, providers_data): - provider_ids = [] - for provider_data in providers_data: - provider, created = Provider.objects.get_or_create( - collection=collection, - name=provider_data["name"], - defaults={ - 'description': provider_data.get('description', None), - 'roles': provider_data.get('roles', None), - 'url': provider_data.get('url', None) - } - ) - logger.debug( - '%s provider %s', - 'created' if created else 'updated', - provider.name, - extra={"provider": provider_data} - ) - provider_ids.append(provider.id) - # the duplicate here is necessary to update the values in - # case the object already exists - provider.description = provider_data.get('description', provider.description) - provider.roles = provider_data.get('roles', provider.roles) - provider.url = provider_data.get('url', provider.url) - provider.full_clean() - provider.save() - - # Delete providers that were not mentioned in the payload anymore - deleted = Provider.objects.filter(collection=collection).exclude(id__in=provider_ids - ).delete() - logger.info( - "deleted %d stale providers for collection %s", - deleted[0], - collection.name, - 
extra={"collection": collection.name} - ) - - def create(self, validated_data): - """ - Create and return a new `Collection` instance, given the validated data. - """ - providers_data = validated_data.pop('providers', []) - links_data = validated_data.pop('links', []) - collection = validate_uniqueness_and_create(Collection, validated_data) - self._update_or_create_providers(collection=collection, providers_data=providers_data) - update_or_create_links( - instance_type="collection", - model=CollectionLink, - instance=collection, - links_data=links_data - ) - return collection - - def update(self, instance, validated_data): - """ - Update and return an existing `Collection` instance, given the validated data. - """ - providers_data = validated_data.pop('providers', []) - links_data = validated_data.pop('links', []) - self._update_or_create_providers(collection=instance, providers_data=providers_data) - update_or_create_links( - instance_type="collection", - model=CollectionLink, - instance=instance, - links_data=links_data - ) - return super().update(instance, validated_data) - - def update_or_create(self, look_up, validated_data): - """ - Update or create the collection object selected by kwargs and return the instance. - - When no collection object matching the kwargs selection, a new object is created. - - Args: - validated_data: dict - Copy of the validated_data to use for update - kwargs: dict - Object selection arguments (NOTE: the selection arguments must match a unique - object in DB otherwise an IntegrityError will be raised) - - Returns: tuple - Collection instance and True if created otherwise false - """ - providers_data = validated_data.pop('providers', []) - links_data = validated_data.pop('links', []) - collection, created = Collection.objects.update_or_create( - **look_up, defaults=validated_data - ) - self._update_or_create_providers(collection=collection, providers_data=providers_data) - update_or_create_links( - instance_type="collection", - model=CollectionLink, - instance=collection, - links_data=links_data - ) - return collection, created - - def to_representation(self, instance): - name = instance.name - request = self.context.get("request") - representation = super().to_representation(instance) - - # Add hardcoded value Collection to response to conform to stac spec v1. - representation['type'] = "Collection" - - # Remove property on older versions - if not is_api_version_1(request): - del representation['type'] - - # Add auto links - # We use OrderedDict, although it is not necessary, because the default serializer/model for - # links already uses OrderedDict, this way we keep consistency between auto link and user - # link - representation['links'][:0] = get_relation_links(request, 'collection-detail', [name]) - return representation - - def validate(self, attrs): - validate_json_payload(self) - return attrs - - -class ItemSerializer(NonNullModelSerializer, UpsertModelSerializerMixin): - - class Meta: - model = Item - fields = [ - 'id', - 'collection', - 'type', - 'stac_version', - 'geometry', - 'bbox', - 'properties', - 'stac_extensions', - 'links', - 'assets' - ] - validators = [] # Remove a default "unique together" constraint. - # (see: - # https://www.django-rest-framework.org/api-guide/validators/#limitations-of-validators) - - # NOTE: when explicitely declaring fields, we need to add the validation as for the field - # in model ! 
- id = serializers.CharField( - source='name', required=True, max_length=255, validators=[validate_name] - ) - properties = ItemsPropertiesSerializer(source='*', required=True) - geometry = gis_serializers.GeometryField(required=True) - links = ItemLinkSerializer(required=False, many=True) - # read only fields - type = serializers.SerializerMethodField() - collection = serializers.SlugRelatedField(slug_field='name', read_only=True) - bbox = BboxSerializer(source='*', read_only=True) - assets = AssetsForItemSerializer(many=True, read_only=True) - stac_extensions = serializers.SerializerMethodField() - stac_version = serializers.SerializerMethodField() - - def get_type(self, obj): - return 'Feature' - - def get_stac_extensions(self, obj): - return [] - - def get_stac_version(self, obj): - return get_stac_version(self.context.get('request')) - - def to_representation(self, instance): - collection = instance.collection.name - name = instance.name - request = self.context.get("request") - representation = super().to_representation(instance) - # Add auto links - # We use OrderedDict, although it is not necessary, because the default serializer/model for - # links already uses OrderedDict, this way we keep consistency between auto link and user - # link - representation['links'][:0] = get_relation_links(request, 'item-detail', [collection, name]) - representation['stac_extensions'] = [ - # Extension provides schema for the 'expires' timestamp - "https://stac-extensions.github.io/timestamps/v1.1.0/schema.json" - ] - return representation - - def create(self, validated_data): - links_data = validated_data.pop('links', []) - item = validate_uniqueness_and_create(Item, validated_data) - update_or_create_links( - instance_type="item", model=ItemLink, instance=item, links_data=links_data - ) - return item - - def update(self, instance, validated_data): - links_data = validated_data.pop('links', []) - update_or_create_links( - instance_type="item", model=ItemLink, instance=instance, links_data=links_data - ) - return super().update(instance, validated_data) - - def update_or_create(self, look_up, validated_data): - """ - Update or create the item object selected by kwargs and return the instance. - When no item object matching the kwargs selection, a new item is created. 
- Args: - validated_data: dict - Copy of the validated_data to use for update - kwargs: dict - Object selection arguments (NOTE: the selection arguments must match a unique - object in DB otherwise an IntegrityError will be raised) - Returns: tuple - Item instance and True if created otherwise false - """ - links_data = validated_data.pop('links', []) - item, created = Item.objects.update_or_create(**look_up, defaults=validated_data) - update_or_create_links( - instance_type="item", model=ItemLink, instance=item, links_data=links_data - ) - return item, created - - def validate(self, attrs): - if ( - not self.partial or \ - 'properties_datetime' in attrs or \ - 'properties_start_datetime' in attrs or \ - 'properties_end_datetime' in attrs or \ - 'properties_expires' in attrs - ): - validate_item_properties_datetimes( - attrs.get('properties_datetime', None), - attrs.get('properties_start_datetime', None), - attrs.get('properties_end_datetime', None), - attrs.get('properties_expires', None) - ) - else: - logger.info( - 'Skip validation of item properties datetimes; partial update without datetimes' - ) - - validate_json_payload(self) - - return attrs - - -class AssetUploadListSerializer(serializers.ListSerializer): - # pylint: disable=abstract-method - - def to_representation(self, data): - return {'uploads': super().to_representation(data)} - - @property - def data(self): - ret = super(serializers.ListSerializer, self).data - return ReturnDict(ret, serializer=self) - - -class UploadPartSerializer(serializers.Serializer): - '''This serializer is used to serialize the data from/to the S3 API. - ''' - # pylint: disable=abstract-method - etag = serializers.CharField(source='ETag', allow_blank=False, required=True) - part_number = serializers.IntegerField( - source='PartNumber', min_value=1, max_value=100, required=True, allow_null=False - ) - modified = serializers.DateTimeField(source='LastModified', required=False, allow_null=True) - size = serializers.IntegerField(source='Size', allow_null=True, required=False) - - -class AssetUploadSerializer(NonNullModelSerializer): - - class Meta: - model = AssetUpload - list_serializer_class = AssetUploadListSerializer - fields = [ - 'upload_id', - 'status', - 'created', - 'checksum_multihash', - 'completed', - 'aborted', - 'number_parts', - 'md5_parts', - 'urls', - 'ended', - 'parts', - 'update_interval', - 'content_encoding' - ] - - checksum_multihash = serializers.CharField( - source='checksum_multihash', - max_length=255, - required=True, - allow_blank=False, - validators=[validate_checksum_multihash_sha256] - ) - md5_parts = serializers.JSONField(required=True) - update_interval = serializers.IntegerField( - required=False, allow_null=False, min_value=-1, max_value=3600, default=-1 - ) - content_encoding = serializers.CharField( - required=False, - allow_null=False, - allow_blank=False, - min_length=1, - max_length=32, - default='', - validators=[validate_content_encoding] - ) - - # write only fields - ended = serializers.DateTimeField(write_only=True, required=False) - parts = serializers.ListField( - child=UploadPartSerializer(), write_only=True, allow_empty=False, required=False - ) - - # Read only fields - upload_id = serializers.CharField(read_only=True) - created = serializers.DateTimeField(read_only=True) - urls = serializers.JSONField(read_only=True) - completed = serializers.SerializerMethodField() - aborted = serializers.SerializerMethodField() - - def validate(self, attrs): - # get partial from kwargs (if partial true and no md5 : ok, if 
false no md5 : error) - # Check the md5 parts length - if attrs.get('md5_parts') is not None: - validate_md5_parts(attrs['md5_parts'], attrs['number_parts']) - elif not self.partial: - raise serializers.ValidationError( - detail={'md5_parts': _('md5_parts parameter is missing')}, code='missing' - ) - return attrs - - def get_completed(self, obj): - if obj.status == AssetUpload.Status.COMPLETED: - return isoformat(obj.ended) - return None - - def get_aborted(self, obj): - if obj.status == AssetUpload.Status.ABORTED: - return isoformat(obj.ended) - return None - - def get_fields(self): - fields = super().get_fields() - # This is a hack to allow fields with special characters - fields['file:checksum'] = fields.pop('checksum_multihash') - - # Older versions of the api still use different name - request = self.context.get('request') - if not is_api_version_1(request): - fields['checksum:multihash'] = fields.pop('file:checksum') - return fields - - -class AssetUploadPartsSerializer(serializers.Serializer): - '''S3 list_parts response serializer''' - - # pylint: disable=abstract-method - - class Meta: - list_serializer_class = AssetUploadListSerializer - - # Read only fields - parts = serializers.ListField( - source='Parts', child=UploadPartSerializer(), default=list, read_only=True - ) diff --git a/app/stac_api/serializers/__init__.py b/app/stac_api/serializers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/stac_api/serializers/collection.py b/app/stac_api/serializers/collection.py new file mode 100644 index 00000000..655d6848 --- /dev/null +++ b/app/stac_api/serializers/collection.py @@ -0,0 +1,413 @@ +import logging + +from django.contrib.gis.geos import GEOSGeometry + +from rest_framework import serializers + +from stac_api.models import Collection +from stac_api.models import CollectionAsset +from stac_api.models import CollectionLink +from stac_api.models import Provider +from stac_api.serializers.utils import AssetsDictSerializer +from stac_api.serializers.utils import HrefField +from stac_api.serializers.utils import NonNullModelSerializer +from stac_api.serializers.utils import UpsertModelSerializerMixin +from stac_api.serializers.utils import get_relation_links +from stac_api.serializers.utils import update_or_create_links +from stac_api.utils import get_stac_version +from stac_api.utils import is_api_version_1 +from stac_api.utils import isoformat +from stac_api.validators import normalize_and_validate_media_type +from stac_api.validators import validate_asset_name +from stac_api.validators import validate_asset_name_with_media_type +from stac_api.validators import validate_name +from stac_api.validators_serializer import validate_json_payload +from stac_api.validators_serializer import validate_uniqueness_and_create + +logger = logging.getLogger(__name__) + + +class ProviderSerializer(NonNullModelSerializer): + + class Meta: + model = Provider + fields = ['name', 'roles', 'url', 'description'] + + +class CollectionLinkSerializer(NonNullModelSerializer): + + class Meta: + model = CollectionLink + fields = ['href', 'rel', 'title', 'type'] + + # NOTE: when explicitely declaring fields, we need to add the validation as for the field + # in model ! 
+ type = serializers.CharField( + required=False, allow_blank=True, max_length=150, source="link_type" + ) + + +class CollectionAssetBaseSerializer(NonNullModelSerializer, UpsertModelSerializerMixin): + '''Collection asset serializer base class + ''' + + class Meta: + model = CollectionAsset + fields = [ + 'id', + 'title', + 'type', + 'href', + 'description', + 'roles', + 'proj_epsg', + 'checksum_multihash', + 'created', + 'updated', + ] + validators = [] # Remove a default "unique together" constraint. + # (see: + # https://www.django-rest-framework.org/api-guide/validators/#limitations-of-validators) + + # NOTE: when explicitely declaring fields, we need to add the validation as for the field + # in model ! + id = serializers.CharField(source='name', max_length=255, validators=[validate_asset_name]) + title = serializers.CharField( + required=False, max_length=255, allow_null=True, allow_blank=False + ) + description = serializers.CharField(required=False, allow_blank=False, allow_null=True) + # Can't be a ChoiceField, as the validate method normalizes the MIME string only after it + # is read. Consistency is nevertheless guaranteed by the validate() and validate_type() methods. + type = serializers.CharField( + source='media_type', required=True, allow_null=False, allow_blank=False + ) + # Here we need to explicitely define these fields with the source, because they are renamed + # in the get_fields() method + proj_epsg = serializers.IntegerField(source='proj_epsg', allow_null=True, required=False) + # read only fields + checksum_multihash = serializers.CharField(source='checksum_multihash', read_only=True) + href = HrefField(source='file', read_only=True) + created = serializers.DateTimeField(read_only=True) + updated = serializers.DateTimeField(read_only=True) + + # helper variable to provide the collection for upsert validation + # see views.AssetDetail.perform_upsert + collection = None + + def create(self, validated_data): + asset = validate_uniqueness_and_create(CollectionAsset, validated_data) + return asset + + def update_or_create(self, look_up, validated_data): + """ + Update or create the asset object selected by kwargs and return the instance. + When no asset object matching the kwargs selection, a new asset is created. 
+ Args: + validated_data: dict + Copy of the validated_data to use for update + kwargs: dict + Object selection arguments (NOTE: the selection arguments must match a unique + object in DB otherwise an IntegrityError will be raised) + Returns: tuple + Asset instance and True if created otherwise false + """ + asset, created = CollectionAsset.objects.update_or_create( + **look_up, + defaults=validated_data, + ) + return asset, created + + def validate_type(self, value): + ''' Validates the the field "type" + ''' + return normalize_and_validate_media_type(value) + + def validate(self, attrs): + name = attrs['name'] if not self.partial else attrs.get('name', self.instance.name) + media_type = attrs['media_type'] if not self.partial else attrs.get( + 'media_type', self.instance.media_type + ) + validate_asset_name_with_media_type(name, media_type) + + validate_json_payload(self) + + return attrs + + def get_fields(self): + fields = super().get_fields() + # This is a hack to allow fields with special characters + fields['proj:epsg'] = fields.pop('proj_epsg') + fields['file:checksum'] = fields.pop('checksum_multihash') + + # Older versions of the api still use different name + request = self.context.get('request') + if not is_api_version_1(request): + fields['checksum:multihash'] = fields.pop('file:checksum') + fields.pop('roles', None) + + return fields + + +class CollectionAssetSerializer(CollectionAssetBaseSerializer): + '''Collection Asset serializer for the collection asset views + + This serializer adds the links list attribute. + ''' + + def to_representation(self, instance): + collection = instance.collection.name + name = instance.name + request = self.context.get("request") + representation = super().to_representation(instance) + # Add auto links + # We use OrderedDict, although it is not necessary, because the default serializer/model for + # links already uses OrderedDict, this way we keep consistency between auto link and user + # link + representation['links'] = get_relation_links( + request, 'collection-asset-detail', [collection, name] + ) + return representation + + +class CollectionAssetsForCollectionSerializer(CollectionAssetBaseSerializer): + '''Collection assets serializer for nesting them inside the collection + + Assets should be nested inside their collection but using a dictionary instead of a list and + without links. + ''' + + class Meta: + model = CollectionAsset + list_serializer_class = AssetsDictSerializer + fields = [ + 'id', + 'title', + 'type', + 'href', + 'description', + 'roles', + 'proj_epsg', + 'checksum_multihash', + 'created', + 'updated' + ] + + +class CollectionSerializer(NonNullModelSerializer, UpsertModelSerializerMixin): + + class Meta: + model = Collection + fields = [ + 'published', + 'stac_version', + 'stac_extensions', + 'id', + 'title', + 'description', + 'summaries', + 'extent', + 'providers', + 'license', + 'created', + 'updated', + 'links', + 'crs', + 'itemType', + 'assets' + ] + # crs not in sample data, but in specs.. + validators = [] # Remove a default "unique together" constraint. + # (see: + # https://www.django-rest-framework.org/api-guide/validators/#limitations-of-validators) + + published = serializers.BooleanField(write_only=True, default=True) + # NOTE: when explicitely declaring fields, we need to add the validation as for the field + # in model ! 
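The `get_fields()` override above is the trick this code base uses for STAC property names that are not valid Python identifiers (`proj:epsg`, `file:checksum`, `geoadmin:variant`, ...): the field is declared under a safe name with an explicit `source`, and the dictionary key is swapped after the fields have been built. A self-contained illustration of the mechanism:

from rest_framework import serializers


class ExampleSerializer(serializers.Serializer):
    # Declared under a Python-safe name; the explicit source keeps pointing at the
    # original attribute after the key is renamed below.
    proj_epsg = serializers.IntegerField(source='proj_epsg', required=False, allow_null=True)

    def get_fields(self):
        fields = super().get_fields()
        fields['proj:epsg'] = fields.pop('proj_epsg')
        return fields


print(ExampleSerializer({'proj_epsg': 2056}).data)  # {'proj:epsg': 2056}

Declaring `source='proj_epsg'` on a field named `proj_epsg` would normally be rejected by DRF as redundant, but because the key is renamed before the field is bound, the explicit source is both allowed and required here — which is what the "we need to explicitly define these fields with the source" comments in the serializers refer to.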
+ id = serializers.CharField( + required=True, max_length=255, source="name", validators=[validate_name] + ) + title = serializers.CharField(required=False, allow_blank=False, default=None, max_length=255) + # Also links are required in the spec, the main links (self, root, items) are automatically + # generated hence here it is set to required=False which allows to add optional links that + # are not generated + links = CollectionLinkSerializer(required=False, many=True) + providers = ProviderSerializer(required=False, many=True) + + # read only fields + crs = serializers.SerializerMethodField(read_only=True) + created = serializers.DateTimeField(read_only=True) + updated = serializers.DateTimeField(read_only=True) + extent = serializers.SerializerMethodField(read_only=True) + summaries = serializers.SerializerMethodField(read_only=True) + stac_extensions = serializers.SerializerMethodField(read_only=True) + stac_version = serializers.SerializerMethodField(read_only=True) + itemType = serializers.ReadOnlyField(default="Feature") # pylint: disable=invalid-name + assets = CollectionAssetsForCollectionSerializer(many=True, read_only=True) + + def get_crs(self, obj): + return ["http://www.opengis.net/def/crs/OGC/1.3/CRS84"] + + def get_stac_extensions(self, obj): + return [] + + def get_stac_version(self, obj): + return get_stac_version(self.context.get('request')) + + def get_summaries(self, obj): + # Older versions of the api still use different name + request = self.context.get('request') + if not is_api_version_1(request): + return { + 'proj:epsg': obj.summaries_proj_epsg or [], + 'eo:gsd': obj.summaries_eo_gsd or [], + 'geoadmin:variant': obj.summaries_geoadmin_variant or [], + 'geoadmin:lang': obj.summaries_geoadmin_lang or [] + } + return { + 'proj:epsg': obj.summaries_proj_epsg or [], + 'gsd': obj.summaries_eo_gsd or [], + 'geoadmin:variant': obj.summaries_geoadmin_variant or [], + 'geoadmin:lang': obj.summaries_geoadmin_lang or [] + } + + def get_extent(self, obj): + start = obj.extent_start_datetime + end = obj.extent_end_datetime + if start is not None: + start = isoformat(start) + if end is not None: + end = isoformat(end) + + bbox = [0, 0, 0, 0] + if obj.extent_geometry is not None: + bbox = list(GEOSGeometry(obj.extent_geometry).extent) + + return { + "spatial": { + "bbox": [bbox] + }, + "temporal": { + "interval": [[start, end]] + }, + } + + def _update_or_create_providers(self, collection, providers_data): + provider_ids = [] + for provider_data in providers_data: + provider, created = Provider.objects.get_or_create( + collection=collection, + name=provider_data["name"], + defaults={ + 'description': provider_data.get('description', None), + 'roles': provider_data.get('roles', None), + 'url': provider_data.get('url', None) + } + ) + logger.debug( + '%s provider %s', + 'created' if created else 'updated', + provider.name, + extra={"provider": provider_data} + ) + provider_ids.append(provider.id) + # the duplicate here is necessary to update the values in + # case the object already exists + provider.description = provider_data.get('description', provider.description) + provider.roles = provider_data.get('roles', provider.roles) + provider.url = provider_data.get('url', provider.url) + provider.full_clean() + provider.save() + + # Delete providers that were not mentioned in the payload anymore + deleted = Provider.objects.filter(collection=collection).exclude(id__in=provider_ids + ).delete() + logger.info( + "deleted %d stale providers for collection %s", + deleted[0], + 
collection.name, + extra={"collection": collection.name} + ) + + def create(self, validated_data): + """ + Create and return a new `Collection` instance, given the validated data. + """ + providers_data = validated_data.pop('providers', []) + links_data = validated_data.pop('links', []) + collection = validate_uniqueness_and_create(Collection, validated_data) + self._update_or_create_providers(collection=collection, providers_data=providers_data) + update_or_create_links( + instance_type="collection", + model=CollectionLink, + instance=collection, + links_data=links_data + ) + return collection + + def update(self, instance, validated_data): + """ + Update and return an existing `Collection` instance, given the validated data. + """ + providers_data = validated_data.pop('providers', []) + links_data = validated_data.pop('links', []) + self._update_or_create_providers(collection=instance, providers_data=providers_data) + update_or_create_links( + instance_type="collection", + model=CollectionLink, + instance=instance, + links_data=links_data + ) + return super().update(instance, validated_data) + + def update_or_create(self, look_up, validated_data): + """ + Update or create the collection object selected by kwargs and return the instance. + + When no collection object matching the kwargs selection, a new object is created. + + Args: + validated_data: dict + Copy of the validated_data to use for update + kwargs: dict + Object selection arguments (NOTE: the selection arguments must match a unique + object in DB otherwise an IntegrityError will be raised) + + Returns: tuple + Collection instance and True if created otherwise false + """ + providers_data = validated_data.pop('providers', []) + links_data = validated_data.pop('links', []) + collection, created = Collection.objects.update_or_create( + **look_up, defaults=validated_data + ) + self._update_or_create_providers(collection=collection, providers_data=providers_data) + update_or_create_links( + instance_type="collection", + model=CollectionLink, + instance=collection, + links_data=links_data + ) + return collection, created + + def to_representation(self, instance): + name = instance.name + request = self.context.get("request") + representation = super().to_representation(instance) + + # Add hardcoded value Collection to response to conform to stac spec v1. 
+ representation['type'] = "Collection" + + # Remove property on older versions + if not is_api_version_1(request): + del representation['type'] + + # Add auto links + # We use OrderedDict, although it is not necessary, because the default serializer/model for + # links already uses OrderedDict, this way we keep consistency between auto link and user + # link + representation['links'][:0] = get_relation_links(request, 'collection-detail', [name]) + return representation + + def validate(self, attrs): + validate_json_payload(self) + return attrs diff --git a/app/stac_api/serializers/general.py b/app/stac_api/serializers/general.py new file mode 100644 index 00000000..4d1aba7e --- /dev/null +++ b/app/stac_api/serializers/general.py @@ -0,0 +1,130 @@ +import logging +from collections import OrderedDict +from urllib.parse import urlparse + +from django.conf import settings +from django.utils.translation import gettext_lazy as _ + +from rest_framework import serializers +from rest_framework.validators import UniqueValidator + +from stac_api.models import LandingPage +from stac_api.models import LandingPageLink +from stac_api.utils import get_browser_url +from stac_api.utils import get_stac_version +from stac_api.utils import get_url +from stac_api.utils import is_api_version_1 +from stac_api.validators import validate_name + +logger = logging.getLogger(__name__) + + +class LandingPageLinkSerializer(serializers.ModelSerializer): + + class Meta: + model = LandingPageLink + fields = ['href', 'rel', 'link_type', 'title'] + + +class ConformancePageSerializer(serializers.ModelSerializer): + + class Meta: + model = LandingPage + fields = ['conformsTo'] + + +class LandingPageSerializer(serializers.ModelSerializer): + + class Meta: + model = LandingPage + fields = ['id', 'title', 'description', 'links', 'stac_version', 'conformsTo'] + + # NOTE: when explicitely declaring fields, we need to add the validation as for the field + # in model ! + id = serializers.CharField( + max_length=255, + source="name", + validators=[validate_name, UniqueValidator(queryset=LandingPage.objects.all())] + ) + # Read only fields + links = LandingPageLinkSerializer(many=True, read_only=True) + stac_version = serializers.SerializerMethodField() + + def get_stac_version(self, obj): + return get_stac_version(self.context.get('request')) + + def to_representation(self, instance): + representation = super().to_representation(instance) + request = self.context.get("request") + + # Add hardcoded value Catalog to response to conform to stac spec v1. 
+ representation['type'] = "Catalog" + + # Remove property on older versions + if not is_api_version_1(request): + del representation['type'] + + version = request.resolver_match.namespace + spec_base = f'{urlparse(settings.STATIC_SPEC_URL).path.strip(' / ')}/{version}' + # Add auto links + # We use OrderedDict, although it is not necessary, because the default serializer/model for + # links already uses OrderedDict, this way we keep consistency between auto link and user + # link + representation['links'][:0] = [ + OrderedDict([ + ('rel', 'root'), + ('href', get_url(request, 'landing-page')), + ("type", "application/json"), + ]), + OrderedDict([ + ('rel', 'self'), + ('href', get_url(request, 'landing-page')), + ("type", "application/json"), + ("title", "This document"), + ]), + OrderedDict([ + ("rel", "service-doc"), + ("href", request.build_absolute_uri(f"/{spec_base}/api.html")), + ("type", "text/html"), + ("title", "The API documentation"), + ]), + OrderedDict([ + ("rel", "service-desc"), + ("href", request.build_absolute_uri(f"/{spec_base}/openapi.yaml")), + ("type", "application/vnd.oai.openapi+yaml;version=3.0"), + ("title", "The OPENAPI description of the service"), + ]), + OrderedDict([ + ("rel", "conformance"), + ("href", get_url(request, 'conformance')), + ("type", "application/json"), + ("title", "OGC API conformance classes implemented by this server"), + ]), + OrderedDict([ + ('rel', 'data'), + ('href', get_url(request, 'collections-list')), + ("type", "application/json"), + ("title", "Information about the feature collections"), + ]), + OrderedDict([ + ("href", get_url(request, 'search-list')), + ("rel", "search"), + ("method", "GET"), + ("type", "application/json"), + ("title", "Search across feature collections"), + ]), + OrderedDict([ + ("href", get_url(request, 'search-list')), + ("rel", "search"), + ("method", "POST"), + ("type", "application/json"), + ("title", "Search across feature collections"), + ]), + OrderedDict([ + ("href", get_browser_url(request, 'browser-catalog')), + ("rel", "alternate"), + ("type", "text/html"), + ("title", "STAC Browser"), + ]), + ] + return representation diff --git a/app/stac_api/serializers/item.py b/app/stac_api/serializers/item.py new file mode 100644 index 00000000..b0cb9fe2 --- /dev/null +++ b/app/stac_api/serializers/item.py @@ -0,0 +1,406 @@ +import logging + +from django.core.exceptions import ValidationError as CoreValidationError +from django.utils.translation import gettext_lazy as _ + +from rest_framework import serializers +from rest_framework_gis import serializers as gis_serializers + +from stac_api.models import Asset +from stac_api.models import Item +from stac_api.models import ItemLink +from stac_api.serializers.utils import AssetsDictSerializer +from stac_api.serializers.utils import HrefField +from stac_api.serializers.utils import NonNullModelSerializer +from stac_api.serializers.utils import UpsertModelSerializerMixin +from stac_api.serializers.utils import get_relation_links +from stac_api.serializers.utils import update_or_create_links +from stac_api.utils import get_stac_version +from stac_api.utils import is_api_version_1 +from stac_api.validators import normalize_and_validate_media_type +from stac_api.validators import validate_asset_name +from stac_api.validators import validate_asset_name_with_media_type +from stac_api.validators import validate_geoadmin_variant +from stac_api.validators import validate_href_url +from stac_api.validators import validate_item_properties_datetimes +from stac_api.validators 
import validate_name +from stac_api.validators_serializer import validate_json_payload +from stac_api.validators_serializer import validate_uniqueness_and_create + +logger = logging.getLogger(__name__) + + +class BboxSerializer(gis_serializers.GeoFeatureModelSerializer): + + class Meta: + model = Item + geo_field = "geometry" + auto_bbox = True + fields = ['geometry'] + + def to_representation(self, instance): + python_native = super().to_representation(instance) + return python_native['bbox'] + + +class ItemLinkSerializer(NonNullModelSerializer): + + class Meta: + model = ItemLink + fields = ['href', 'rel', 'title', 'type'] + + # NOTE: when explicitely declaring fields, we need to add the validation as for the field + # in model ! + type = serializers.CharField( + required=False, allow_blank=True, max_length=255, source="link_type" + ) + + +class ItemsPropertiesSerializer(serializers.Serializer): + # pylint: disable=abstract-method + # ItemsPropertiesSerializer is a nested serializer and don't directly create/write instances + # therefore we don't need to implement the super method create() and update() + + # NOTE: when explicitely declaring fields, we need to add the validation as for the field + # in model ! + datetime = serializers.DateTimeField(source='properties_datetime', required=False, default=None) + start_datetime = serializers.DateTimeField( + source='properties_start_datetime', required=False, default=None + ) + end_datetime = serializers.DateTimeField( + source='properties_end_datetime', required=False, default=None + ) + title = serializers.CharField( + source='properties_title', + required=False, + allow_blank=False, + allow_null=True, + max_length=255, + default=None + ) + created = serializers.DateTimeField(read_only=True) + updated = serializers.DateTimeField(read_only=True) + expires = serializers.DateTimeField(source='properties_expires', required=False, default=None) + + +class AssetBaseSerializer(NonNullModelSerializer, UpsertModelSerializerMixin): + '''Asset serializer base class + ''' + + class Meta: + model = Asset + fields = [ + 'id', + 'title', + 'type', + 'href', + 'description', + 'eo_gsd', + 'roles', + 'geoadmin_lang', + 'geoadmin_variant', + 'proj_epsg', + 'checksum_multihash', + 'created', + 'updated', + ] + validators = [] # Remove a default "unique together" constraint. + # (see: + # https://www.django-rest-framework.org/api-guide/validators/#limitations-of-validators) + + # NOTE: when explicitely declaring fields, we need to add the validation as for the field + # in model ! + id = serializers.CharField(source='name', max_length=255, validators=[validate_asset_name]) + title = serializers.CharField( + required=False, max_length=255, allow_null=True, allow_blank=False + ) + description = serializers.CharField(required=False, allow_blank=False, allow_null=True) + # Can't be a ChoiceField, as the validate method normalizes the MIME string only after it + # is read. Consistency is nevertheless guaranteed by the validate() and validate_type() methods. 
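Because type is read as a plain CharField, normalisation happens in the field-level validate_type() hook defined further down: DRF replaces the incoming value with whatever that hook returns. A generic, self-contained sketch of the pattern (the lower-casing rule is only an example; the real rules live in normalize_and_validate_media_type):

from rest_framework import serializers

class MediaTypeExampleSerializer(serializers.Serializer):
    type = serializers.CharField()

    def validate_type(self, value):
        # the return value is what ends up in validated_data
        return value.strip().lower()

example = MediaTypeExampleSerializer(data={'type': ' Image/TIFF '})
example.is_valid(raise_exception=True)
assert example.validated_data['type'] == 'image/tiff'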
+    type = serializers.CharField(
+        source='media_type', required=True, allow_null=False, allow_blank=False
+    )
+    # Here we need to explicitly define these fields with the source, because they are renamed
+    # in the get_fields() method
+    eo_gsd = serializers.FloatField(source='eo_gsd', required=False, allow_null=True)
+    geoadmin_lang = serializers.ChoiceField(
+        source='geoadmin_lang',
+        choices=Asset.Language.values,
+        required=False,
+        allow_null=True,
+        allow_blank=False
+    )
+    geoadmin_variant = serializers.CharField(
+        source='geoadmin_variant',
+        max_length=25,
+        allow_blank=False,
+        allow_null=True,
+        required=False,
+        validators=[validate_geoadmin_variant]
+    )
+    proj_epsg = serializers.IntegerField(source='proj_epsg', allow_null=True, required=False)
+    # read only fields
+    checksum_multihash = serializers.CharField(source='checksum_multihash', read_only=True)
+    href = HrefField(source='file', required=False)
+    created = serializers.DateTimeField(read_only=True)
+    updated = serializers.DateTimeField(read_only=True)
+
+    # helper variable to provide the collection for upsert validation
+    # see views.AssetDetail.perform_upsert
+    collection = None
+
+    def create(self, validated_data):
+        asset = validate_uniqueness_and_create(Asset, validated_data)
+        return asset
+
+    def update_or_create(self, look_up, validated_data):
+        """
+        Update or create the asset object selected by kwargs and return the instance.
+        When no asset object matches the kwargs selection, a new asset is created.
+        Args:
+            validated_data: dict
+                Copy of the validated_data to use for update
+            kwargs: dict
+                Object selection arguments (NOTE: the selection arguments must match a unique
+                object in DB otherwise an IntegrityError will be raised)
+        Returns: tuple
+            Asset instance and True if created, otherwise False
+        """
+        asset, created = Asset.objects.update_or_create(**look_up, defaults=validated_data)
+        return asset, created
+
+    def validate_type(self, value):
+        ''' Validates the field "type"
+        '''
+        return normalize_and_validate_media_type(value)
+
+    def validate(self, attrs):
+        name = attrs['name'] if not self.partial else attrs.get('name', self.instance.name)
+        media_type = attrs['media_type'] if not self.partial else attrs.get(
+            'media_type', self.instance.media_type
+        )
+        validate_asset_name_with_media_type(name, media_type)
+
+        validate_json_payload(self)
+
+        return attrs
+
+    def get_fields(self):
+        fields = super().get_fields()
+        # This is a hack to allow fields with special characters
+        fields['gsd'] = fields.pop('eo_gsd')
+        fields['proj:epsg'] = fields.pop('proj_epsg')
+        fields['geoadmin:variant'] = fields.pop('geoadmin_variant')
+        fields['geoadmin:lang'] = fields.pop('geoadmin_lang')
+        fields['file:checksum'] = fields.pop('checksum_multihash')
+
+        # Older versions of the API still use different names
+        request = self.context.get('request')
+        if not is_api_version_1(request):
+            fields['checksum:multihash'] = fields.pop('file:checksum')
+            fields['eo:gsd'] = fields.pop('gsd')
+            fields.pop('roles', None)
+
+        return fields
+
+
+class AssetSerializer(AssetBaseSerializer):
+    '''Asset serializer for the asset views
+
+    This serializer adds the links list attribute.
+ ''' + + def to_representation(self, instance): + collection = instance.item.collection.name + item = instance.item.name + name = instance.name + request = self.context.get("request") + representation = super().to_representation(instance) + # Add auto links + # We use OrderedDict, although it is not necessary, because the default serializer/model for + # links already uses OrderedDict, this way we keep consistency between auto link and user + # link + representation['links'] = get_relation_links( + request, 'asset-detail', [collection, item, name] + ) + return representation + + def _validate_href_field(self, attrs): + """Only allow the href field if the collection allows for external assets + + Raise an exception, this replicates the previous behaviour when href + was always read_only + """ + # the href field is translated to the file field here + if 'file' in attrs: + if self.collection: + collection = self.collection + else: + raise LookupError("No collection defined.") + + if not collection.allow_external_assets: + logger.info( + 'Attempted external asset upload with no permission', + extra={ + 'collection': self.collection, 'attrs': attrs + } + ) + errors = {'href': _("Found read-only property in payload")} + raise serializers.ValidationError(code="payload", detail=errors) + + try: + validate_href_url(attrs['file'], collection) + except CoreValidationError as e: + errors = {'href': e.message} + raise serializers.ValidationError(code='payload', detail=errors) + + def validate(self, attrs): + self._validate_href_field(attrs) + return super().validate(attrs) + + +class AssetsForItemSerializer(AssetBaseSerializer): + '''Assets serializer for nesting them inside the item + + Assets should be nested inside their item but using a dictionary instead of a list and without + links. + ''' + + class Meta: + model = Asset + list_serializer_class = AssetsDictSerializer + fields = [ + 'id', + 'title', + 'type', + 'href', + 'description', + 'roles', + 'eo_gsd', + 'geoadmin_lang', + 'geoadmin_variant', + 'proj_epsg', + 'checksum_multihash', + 'created', + 'updated' + ] + + +class ItemSerializer(NonNullModelSerializer, UpsertModelSerializerMixin): + + class Meta: + model = Item + fields = [ + 'id', + 'collection', + 'type', + 'stac_version', + 'geometry', + 'bbox', + 'properties', + 'stac_extensions', + 'links', + 'assets' + ] + validators = [] # Remove a default "unique together" constraint. + # (see: + # https://www.django-rest-framework.org/api-guide/validators/#limitations-of-validators) + + # NOTE: when explicitely declaring fields, we need to add the validation as for the field + # in model ! 
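A rough sketch of how the _validate_href_field guard above is meant to be driven; the payload values, request and item objects are illustrative, and setting serializer.collection is normally done by the view (see AssetDetail.get_serializer):

serializer = AssetSerializer(
    data={
        'id': 'data.tiff',
        'type': 'image/tiff; application=geotiff',
        'href': 'https://example.com/data.tiff',
    },
    context={'request': request},
)
serializer.collection = item.collection
# validation fails on 'href' unless the collection has allow_external_assets set
serializer.is_valid()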
+ id = serializers.CharField( + source='name', required=True, max_length=255, validators=[validate_name] + ) + properties = ItemsPropertiesSerializer(source='*', required=True) + geometry = gis_serializers.GeometryField(required=True) + links = ItemLinkSerializer(required=False, many=True) + # read only fields + type = serializers.SerializerMethodField() + collection = serializers.SlugRelatedField(slug_field='name', read_only=True) + bbox = BboxSerializer(source='*', read_only=True) + assets = AssetsForItemSerializer(many=True, read_only=True) + stac_extensions = serializers.SerializerMethodField() + stac_version = serializers.SerializerMethodField() + + def get_type(self, obj): + return 'Feature' + + def get_stac_extensions(self, obj): + return [] + + def get_stac_version(self, obj): + return get_stac_version(self.context.get('request')) + + def to_representation(self, instance): + collection = instance.collection.name + name = instance.name + request = self.context.get("request") + representation = super().to_representation(instance) + # Add auto links + # We use OrderedDict, although it is not necessary, because the default serializer/model for + # links already uses OrderedDict, this way we keep consistency between auto link and user + # link + representation['links'][:0] = get_relation_links(request, 'item-detail', [collection, name]) + representation['stac_extensions'] = [ + # Extension provides schema for the 'expires' timestamp + "https://stac-extensions.github.io/timestamps/v1.1.0/schema.json" + ] + return representation + + def create(self, validated_data): + links_data = validated_data.pop('links', []) + item = validate_uniqueness_and_create(Item, validated_data) + update_or_create_links( + instance_type="item", model=ItemLink, instance=item, links_data=links_data + ) + return item + + def update(self, instance, validated_data): + links_data = validated_data.pop('links', []) + update_or_create_links( + instance_type="item", model=ItemLink, instance=instance, links_data=links_data + ) + return super().update(instance, validated_data) + + def update_or_create(self, look_up, validated_data): + """ + Update or create the item object selected by kwargs and return the instance. + When no item object matching the kwargs selection, a new item is created. 
+ Args: + validated_data: dict + Copy of the validated_data to use for update + kwargs: dict + Object selection arguments (NOTE: the selection arguments must match a unique + object in DB otherwise an IntegrityError will be raised) + Returns: tuple + Item instance and True if created otherwise false + """ + links_data = validated_data.pop('links', []) + item, created = Item.objects.update_or_create(**look_up, defaults=validated_data) + update_or_create_links( + instance_type="item", model=ItemLink, instance=item, links_data=links_data + ) + return item, created + + def validate(self, attrs): + if ( + not self.partial or \ + 'properties_datetime' in attrs or \ + 'properties_start_datetime' in attrs or \ + 'properties_end_datetime' in attrs or \ + 'properties_expires' in attrs + ): + validate_item_properties_datetimes( + attrs.get('properties_datetime', None), + attrs.get('properties_start_datetime', None), + attrs.get('properties_end_datetime', None), + attrs.get('properties_expires', None) + ) + else: + logger.info( + 'Skip validation of item properties datetimes; partial update without datetimes' + ) + + validate_json_payload(self) + + return attrs diff --git a/app/stac_api/serializers/upload.py b/app/stac_api/serializers/upload.py new file mode 100644 index 00000000..9a42887c --- /dev/null +++ b/app/stac_api/serializers/upload.py @@ -0,0 +1,231 @@ +import logging + +from django.utils.translation import gettext_lazy as _ + +from rest_framework import serializers +from rest_framework.utils.serializer_helpers import ReturnDict + +from stac_api.models import AssetUpload +from stac_api.models import CollectionAssetUpload +from stac_api.serializers.utils import NonNullModelSerializer +from stac_api.utils import is_api_version_1 +from stac_api.utils import isoformat +from stac_api.validators import validate_checksum_multihash_sha256 +from stac_api.validators import validate_content_encoding +from stac_api.validators import validate_md5_parts + +logger = logging.getLogger(__name__) + + +class AssetUploadListSerializer(serializers.ListSerializer): + # pylint: disable=abstract-method + + def to_representation(self, data): + return {'uploads': super().to_representation(data)} + + @property + def data(self): + ret = super(serializers.ListSerializer, self).data + return ReturnDict(ret, serializer=self) + + +class UploadPartSerializer(serializers.Serializer): + '''This serializer is used to serialize the data from/to the S3 API. 
+ ''' + # pylint: disable=abstract-method + etag = serializers.CharField(source='ETag', allow_blank=False, required=True) + part_number = serializers.IntegerField( + source='PartNumber', min_value=1, max_value=100, required=True, allow_null=False + ) + modified = serializers.DateTimeField(source='LastModified', required=False, allow_null=True) + size = serializers.IntegerField(source='Size', allow_null=True, required=False) + + +class AssetUploadSerializer(NonNullModelSerializer): + + class Meta: + model = AssetUpload + list_serializer_class = AssetUploadListSerializer + fields = [ + 'upload_id', + 'status', + 'created', + 'checksum_multihash', + 'completed', + 'aborted', + 'number_parts', + 'md5_parts', + 'urls', + 'ended', + 'parts', + 'update_interval', + 'content_encoding' + ] + + checksum_multihash = serializers.CharField( + source='checksum_multihash', + max_length=255, + required=True, + allow_blank=False, + validators=[validate_checksum_multihash_sha256] + ) + md5_parts = serializers.JSONField(required=True) + update_interval = serializers.IntegerField( + required=False, allow_null=False, min_value=-1, max_value=3600, default=-1 + ) + content_encoding = serializers.CharField( + required=False, + allow_null=False, + allow_blank=False, + min_length=1, + max_length=32, + default='', + validators=[validate_content_encoding] + ) + + # write only fields + ended = serializers.DateTimeField(write_only=True, required=False) + parts = serializers.ListField( + child=UploadPartSerializer(), write_only=True, allow_empty=False, required=False + ) + + # Read only fields + upload_id = serializers.CharField(read_only=True) + created = serializers.DateTimeField(read_only=True) + urls = serializers.JSONField(read_only=True) + completed = serializers.SerializerMethodField() + aborted = serializers.SerializerMethodField() + + def validate(self, attrs): + # get partial from kwargs (if partial true and no md5 : ok, if false no md5 : error) + # Check the md5 parts length + if attrs.get('md5_parts') is not None: + validate_md5_parts(attrs['md5_parts'], attrs['number_parts']) + elif not self.partial: + raise serializers.ValidationError( + detail={'md5_parts': _('md5_parts parameter is missing')}, code='missing' + ) + return attrs + + def get_completed(self, obj): + if obj.status == AssetUpload.Status.COMPLETED: + return isoformat(obj.ended) + return None + + def get_aborted(self, obj): + if obj.status == AssetUpload.Status.ABORTED: + return isoformat(obj.ended) + return None + + def get_fields(self): + fields = super().get_fields() + # This is a hack to allow fields with special characters + fields['file:checksum'] = fields.pop('checksum_multihash') + + # Older versions of the api still use different name + request = self.context.get('request') + if not is_api_version_1(request): + fields['checksum:multihash'] = fields.pop('file:checksum') + return fields + + +class AssetUploadPartsSerializer(serializers.Serializer): + '''S3 list_parts response serializer''' + + # pylint: disable=abstract-method + + class Meta: + list_serializer_class = AssetUploadListSerializer + + # Read only fields + parts = serializers.ListField( + source='Parts', child=UploadPartSerializer(), default=list, read_only=True + ) + + +class CollectionAssetUploadSerializer(NonNullModelSerializer): + + class Meta: + model = CollectionAssetUpload + list_serializer_class = AssetUploadListSerializer + fields = [ + 'upload_id', + 'status', + 'created', + 'checksum_multihash', + 'completed', + 'aborted', + 'number_parts', + 'md5_parts', + 'urls', 
+ 'ended', + 'parts', + 'update_interval', + 'content_encoding' + ] + + checksum_multihash = serializers.CharField( + source='checksum_multihash', + max_length=255, + required=True, + allow_blank=False, + validators=[validate_checksum_multihash_sha256] + ) + md5_parts = serializers.JSONField(required=True) + update_interval = serializers.IntegerField( + required=False, allow_null=False, min_value=-1, max_value=3600, default=-1 + ) + content_encoding = serializers.CharField( + required=False, + allow_null=False, + allow_blank=False, + min_length=1, + max_length=32, + default='', + validators=[validate_content_encoding] + ) + + # write only fields + ended = serializers.DateTimeField(write_only=True, required=False) + parts = serializers.ListField( + child=UploadPartSerializer(), write_only=True, allow_empty=False, required=False + ) + + # Read only fields + upload_id = serializers.CharField(read_only=True) + created = serializers.DateTimeField(read_only=True) + urls = serializers.JSONField(read_only=True) + completed = serializers.SerializerMethodField() + aborted = serializers.SerializerMethodField() + + def validate(self, attrs): + # get partial from kwargs (if partial true and no md5 : ok, if false no md5 : error) + # Check the md5 parts length + if attrs.get('md5_parts') is not None: + validate_md5_parts(attrs['md5_parts'], attrs['number_parts']) + elif not self.partial: + raise serializers.ValidationError( + detail={'md5_parts': _('md5_parts parameter is missing')}, code='missing' + ) + return attrs + + def get_completed(self, obj): + if obj.status == CollectionAssetUpload.Status.COMPLETED: + return isoformat(obj.ended) + return None + + def get_aborted(self, obj): + if obj.status == CollectionAssetUpload.Status.ABORTED: + return isoformat(obj.ended) + return None + + def get_fields(self): + fields = super().get_fields() + # This is a hack to allow fields with special characters + fields['file:checksum'] = fields.pop('checksum_multihash') + + # Older versions of the api still use different name + request = self.context.get('request') + if not is_api_version_1(request): + fields['checksum:multihash'] = fields.pop('file:checksum') + return fields diff --git a/app/stac_api/serializers_utils.py b/app/stac_api/serializers/utils.py similarity index 90% rename from app/stac_api/serializers_utils.py rename to app/stac_api/serializers/utils.py index 03f950ae..70ae5220 100644 --- a/app/stac_api/serializers_utils.py +++ b/app/stac_api/serializers/utils.py @@ -4,6 +4,7 @@ from rest_framework import serializers from rest_framework.utils.serializer_helpers import ReturnDict +from stac_api.utils import build_asset_href from stac_api.utils import get_browser_url from stac_api.utils import get_url @@ -72,6 +73,14 @@ def update_or_create_links(model, instance, instance_type, links_data): 'parent': 'landing-page', 'browser': 'browser-collection', }, + 'collection-assets-list': { + 'parent': 'collection-detail', + 'browser': None, + }, + 'collection-asset-detail': { + 'parent': 'collection-detail', + 'browser': None, + }, 'items-list': { 'parent': 'collection-detail', 'browser': 'browser-collection', @@ -107,6 +116,8 @@ def get_parent_link(request, view, view_args=()): ''' def parent_args(view, args): + if view.startswith('collection-asset'): + return args[:1] if view.startswith('item'): return args[:1] if view.startswith('asset'): @@ -301,3 +312,30 @@ def to_representation(self, data): def data(self): ret = super(serializers.ListSerializer, self).data return ReturnDict(ret, serializer=self) + + +class 
AssetsDictSerializer(DictSerializer):
+    '''Assets serializer list to dictionary
+
+    This serializer returns an asset dictionary with the asset names as keys.
+    '''
+    # pylint: disable=abstract-method
+    key_identifier = 'id'
+
+
+class HrefField(serializers.Field):
+    '''Special Href field for Assets'''
+
+    # pylint: disable=abstract-method
+
+    def to_representation(self, value):
+        # build an absolute URL from the file path
+        request = self.context.get("request")
+        path = value.name
+
+        if value.instance.is_external:
+            return path
+        return build_asset_href(request, path)
+
+    def to_internal_value(self, data):
+        return data
diff --git a/app/stac_api/signals.py b/app/stac_api/signals.py
index 8c8063b5..0410b70f 100644
--- a/app/stac_api/signals.py
+++ b/app/stac_api/signals.py
@@ -6,6 +6,8 @@
 from stac_api.models import Asset
 from stac_api.models import AssetUpload
+from stac_api.models import CollectionAsset
+from stac_api.models import CollectionAssetUpload
 
 logger = logging.getLogger(__name__)
 
 
@@ -29,6 +31,24 @@ def check_on_going_upload(sender, instance, **kwargs):
         )
 
 
+@receiver(pre_delete, sender=CollectionAssetUpload)
+def check_on_going_collection_asset_upload(sender, instance, **kwargs):
+    if instance.status == CollectionAssetUpload.Status.IN_PROGRESS:
+        logger.error(
+            "Cannot delete collection asset %s due to upload %s which is still in progress",
+            instance.asset.name,
+            instance.upload_id,
+            extra={
+                'upload_id': instance.upload_id,
+                'asset': instance.asset.name,
+                'collection': instance.asset.collection.name
+            }
+        )
+        raise ProtectedError(
+            f"Collection Asset {instance.asset.name} still has an upload in progress", [instance]
+        )
+
+
 @receiver(pre_delete, sender=Asset)
 def delete_s3_asset(sender, instance, **kwargs):
     # The file is not automatically deleted by Django
@@ -36,3 +56,12 @@ def delete_s3_asset(sender, instance, **kwargs):
     # hence it has to be done here.
     logger.info("The asset %s is deleted from s3", instance.file.name)
     instance.file.delete(save=False)
+
+
+@receiver(pre_delete, sender=CollectionAsset)
+def delete_s3_collection_asset(sender, instance, **kwargs):
+    # The file is not automatically deleted by Django
+    # when the object holding its reference is deleted
+    # hence it has to be done here.
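For callers, the new pre_delete guard means that deleting a collection asset with an in-progress upload now fails loudly; a small sketch, assuming an existing collection_asset instance:

from django.db.models import ProtectedError

try:
    collection_asset.delete()
except ProtectedError:
    # abort or complete the pending CollectionAssetUpload first, then delete again
    pass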
+ logger.info("The collection asset %s is deleted from s3", instance.file.name) + instance.file.delete(save=False) diff --git a/app/stac_api/urls.py b/app/stac_api/urls.py index bd23ec38..5ccaf6cd 100644 --- a/app/stac_api/urls.py +++ b/app/stac_api/urls.py @@ -4,21 +4,28 @@ from rest_framework.authtoken.views import obtain_auth_token -from stac_api.views import AssetDetail -from stac_api.views import AssetsList -from stac_api.views import AssetUploadAbort -from stac_api.views import AssetUploadComplete -from stac_api.views import AssetUploadDetail -from stac_api.views import AssetUploadPartsList -from stac_api.views import AssetUploadsList -from stac_api.views import CollectionDetail -from stac_api.views import CollectionList -from stac_api.views import ConformancePageDetail -from stac_api.views import ItemDetail -from stac_api.views import ItemsList -from stac_api.views import LandingPageDetail -from stac_api.views import SearchList -from stac_api.views import recalculate_extent +from stac_api.views.collection import CollectionAssetDetail +from stac_api.views.collection import CollectionAssetsList +from stac_api.views.collection import CollectionDetail +from stac_api.views.collection import CollectionList +from stac_api.views.general import ConformancePageDetail +from stac_api.views.general import LandingPageDetail +from stac_api.views.general import SearchList +from stac_api.views.general import recalculate_extent +from stac_api.views.item import AssetDetail +from stac_api.views.item import AssetsList +from stac_api.views.item import ItemDetail +from stac_api.views.item import ItemsList +from stac_api.views.upload import AssetUploadAbort +from stac_api.views.upload import AssetUploadComplete +from stac_api.views.upload import AssetUploadDetail +from stac_api.views.upload import AssetUploadPartsList +from stac_api.views.upload import AssetUploadsList +from stac_api.views.upload import CollectionAssetUploadAbort +from stac_api.views.upload import CollectionAssetUploadComplete +from stac_api.views.upload import CollectionAssetUploadDetail +from stac_api.views.upload import CollectionAssetUploadPartsList +from stac_api.views.upload import CollectionAssetUploadsList # HEALTHCHECK_ENDPOINT = settings.HEALTHCHECK_ENDPOINT @@ -41,7 +48,46 @@ path("/assets/", include(asset_urls)) ] +collection_asset_upload_urls = [ + path( + "", CollectionAssetUploadDetail.as_view(), name='collection-asset-upload-detail' + ), + path( + "/parts", + CollectionAssetUploadPartsList.as_view(), + name='collection-asset-upload-parts-list' + ), + path( + "/complete", + CollectionAssetUploadComplete.as_view(), + name='collection-asset-upload-complete' + ), + path( + "/abort", + CollectionAssetUploadAbort.as_view(), + name='collection-asset-upload-abort' + ), +] + +collection_asset_urls = [ + path("", CollectionAssetDetail.as_view(), name='collection-asset-detail'), + path( + "/uploads", + CollectionAssetUploadsList.as_view(), + name='collection-asset-uploads-list' + ), + path("/uploads/", include(collection_asset_upload_urls)) +] + collection_urls = [ + path("", CollectionDetail.as_view(), name='collection-detail'), + path("/items", ItemsList.as_view(), name='items-list'), + path("/items/", include(item_urls)), + path("/assets", CollectionAssetsList.as_view(), name='collection-assets-list'), + path("/assets/", include(collection_asset_urls)) +] + +collection_urls_v09 = [ path("", CollectionDetail.as_view(), name='collection-detail'), path("/items", ItemsList.as_view(), name='items-list'), path("/items/", 
include(item_urls)) @@ -58,7 +104,7 @@ path("conformance", ConformancePageDetail.as_view(), name='conformance'), path("search", SearchList.as_view(), name='search-list'), path("collections", CollectionList.as_view(), name='collections-list'), - path("collections/", include(collection_urls)), + path("collections/", include(collection_urls_v09)), path("update-extent", recalculate_extent) ], "v0.9"), diff --git a/app/stac_api/validators_view.py b/app/stac_api/validators_view.py index 1200f879..c8bd571d 100644 --- a/app/stac_api/validators_view.py +++ b/app/stac_api/validators_view.py @@ -9,6 +9,7 @@ from stac_api.models import Asset from stac_api.models import Collection +from stac_api.models import CollectionAsset from stac_api.models import Item logger = logging.getLogger(__name__) @@ -82,6 +83,30 @@ def validate_asset(kwargs): ) +def validate_collection_asset(kwargs): + '''Validate that the collection asset given in request kwargs exists + + Args: + kwargs: dict + request kwargs dictionary + + Raises: + Http404: when the asset doesn't exists + ''' + if not CollectionAsset.objects.filter( + name=kwargs['asset_name'], collection__name=kwargs['collection_name'] + ).exists(): + logger.error( + "The asset %s is not part of the collection %s", + kwargs['asset_name'], + kwargs['collection_name'] + ) + raise Http404( + f"The asset {kwargs['asset_name']} is not part of " + f"the collection {kwargs['collection_name']}" + ) + + def validate_renaming(serializer, original_id, id_field='name', extra_log=None): '''Validate that the object name is not different from the one defined in the data. diff --git a/app/stac_api/views.py b/app/stac_api/views.py deleted file mode 100644 index 2d75d099..00000000 --- a/app/stac_api/views.py +++ /dev/null @@ -1,863 +0,0 @@ -import json -import logging -from datetime import datetime -from operator import itemgetter - -from django.conf import settings -from django.db import IntegrityError -from django.db import transaction -from django.db.models import Min -from django.db.models import Prefetch -from django.db.models import Q -from django.utils import timezone -from django.utils.translation import gettext_lazy as _ - -from rest_framework import generics -from rest_framework import mixins -from rest_framework import permissions -from rest_framework import serializers -from rest_framework.decorators import api_view -from rest_framework.decorators import permission_classes -from rest_framework.exceptions import APIException -from rest_framework.generics import get_object_or_404 -from rest_framework.permissions import AllowAny -from rest_framework.response import Response -from rest_framework_condition import etag - -from stac_api import views_mixins -from stac_api.exceptions import UploadInProgressError -from stac_api.exceptions import UploadNotInProgressError -from stac_api.models import Asset -from stac_api.models import AssetUpload -from stac_api.models import Collection -from stac_api.models import Item -from stac_api.models import LandingPage -from stac_api.pagination import ExtApiPagination -from stac_api.pagination import GetPostCursorPagination -from stac_api.s3_multipart_upload import MultipartUpload -from stac_api.serializers import AssetSerializer -from stac_api.serializers import AssetUploadPartsSerializer -from stac_api.serializers import AssetUploadSerializer -from stac_api.serializers import CollectionSerializer -from stac_api.serializers import ConformancePageSerializer -from stac_api.serializers import ItemSerializer -from stac_api.serializers import 
LandingPageSerializer -from stac_api.serializers_utils import get_relation_links -from stac_api.utils import call_calculate_extent -from stac_api.utils import get_asset_path -from stac_api.utils import harmonize_post_get_for_search -from stac_api.utils import is_api_version_1 -from stac_api.utils import select_s3_bucket -from stac_api.utils import utc_aware -from stac_api.validators_serializer import ValidateSearchRequest -from stac_api.validators_view import validate_asset -from stac_api.validators_view import validate_collection -from stac_api.validators_view import validate_item -from stac_api.validators_view import validate_renaming - -logger = logging.getLogger(__name__) - - -def get_etag(queryset): - if queryset.exists(): - return list(queryset.only('etag').values('etag').first().values())[0] - return None - - -def get_collection_etag(request, *args, **kwargs): - '''Get the ETag for a collection object - - The ETag is an UUID4 computed on each object changes (including relations; provider and links) - ''' - tag = get_etag(Collection.objects.filter(name=kwargs['collection_name'])) - - if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE: - logger.debug( - "Output of EXPLAIN.. ANALYZE from get_collection_etag():\n%s", - Collection.objects.filter(name=kwargs['collection_name'] - ).explain(verbose=True, analyze=True) - ) - logger.debug( - "The corresponding SQL statement:\n%s", - Collection.objects.filter(name=kwargs['collection_name']).query - ) - - return tag - - -def get_item_etag(request, *args, **kwargs): - '''Get the ETag for a item object - - The ETag is an UUID4 computed on each object changes (including relations; assets and links) - ''' - tag = get_etag( - Item.objects.filter(collection__name=kwargs['collection_name'], name=kwargs['item_name']) - ) - - if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE: - logger.debug( - "Output of EXPLAIN.. ANALYZE from get_item_etag():\n%s", - Item.objects.filter( - collection__name=kwargs['collection_name'], name=kwargs['item_name'] - ).explain(verbose=True, analyze=True) - ) - logger.debug( - "The corresponding SQL statement:\n%s", - Item.objects.filter( - collection__name=kwargs['collection_name'], name=kwargs['item_name'] - ).query - ) - - return tag - - -def get_asset_etag(request, *args, **kwargs): - '''Get the ETag for a asset object - - The ETag is an UUID4 computed on each object changes - ''' - tag = get_etag( - Asset.objects.filter( - item__collection__name=kwargs['collection_name'], - item__name=kwargs['item_name'], - name=kwargs['asset_name'] - ) - ) - - if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE: - logger.debug( - "Output of EXPLAIN.. 
ANALYZE from get_asset_etag():\n%s", - Asset.objects.filter(item__name=kwargs['item_name'], - name=kwargs['asset_name']).explain(verbose=True, analyze=True) - ) - logger.debug( - "The corresponding SQL statement:\n%s", - Asset.objects.filter(item__name=kwargs['item_name'], name=kwargs['asset_name']).query - ) - - return tag - - -def get_asset_upload_etag(request, *args, **kwargs): - '''Get the ETag for an asset upload object - - The ETag is an UUID4 computed on each object changes - ''' - return get_etag( - AssetUpload.objects.filter( - asset__item__collection__name=kwargs['collection_name'], - asset__item__name=kwargs['item_name'], - asset__name=kwargs['asset_name'], - upload_id=kwargs['upload_id'] - ) - ) - - -class LandingPageDetail(generics.RetrieveAPIView): - name = 'landing-page' # this name must match the name in urls.py - serializer_class = LandingPageSerializer - queryset = LandingPage.objects.all() - - def get_object(self): - if not is_api_version_1(self.request): - return LandingPage.objects.get(version='v0.9') - return LandingPage.objects.get(version='v1') - - -class ConformancePageDetail(generics.RetrieveAPIView): - name = 'conformance' # this name must match the name in urls.py - serializer_class = ConformancePageSerializer - queryset = LandingPage.objects.all() - - def get_object(self): - if not is_api_version_1(self.request): - return LandingPage.objects.get(version='v0.9') - return LandingPage.objects.get(version='v1') - - -class SearchList(generics.GenericAPIView, mixins.ListModelMixin): - name = 'search-list' # this name must match the name in urls.py - permission_classes = [AllowAny] - serializer_class = ItemSerializer - pagination_class = GetPostCursorPagination - # It is important to order the result by a unique identifier, because the search endpoint - # search overall collections and that the item name is only unique within a collection - # we must use the pk as ordering attribute, otherwise the cursor pagination will not work - ordering = ['pk'] - - def get_queryset(self): - queryset = Item.objects.filter(collection__published=True - ).prefetch_related('assets', 'links') - # harmonize GET and POST query - query_param = harmonize_post_get_for_search(self.request) - - # build queryset - - # if ids, then the other params will be ignored - if 'ids' in query_param: - queryset = queryset.filter_by_item_name(query_param['ids']) - else: - if 'bbox' in query_param: - queryset = queryset.filter_by_bbox(query_param['bbox']) - if 'datetime' in query_param: - queryset = queryset.filter_by_datetime(query_param['datetime']) - if 'collections' in query_param: - queryset = queryset.filter_by_collections(query_param['collections']) - if 'query' in query_param: - dict_query = json.loads(query_param['query']) - queryset = queryset.filter_by_query(dict_query) - if 'intersects' in query_param: - queryset = queryset.filter_by_intersects(json.dumps(query_param['intersects'])) - - if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE: - logger.debug( - "Output of EXPLAIN.. 
ANALYZE from SearchList() view:\n%s", - queryset.explain(verbose=True, analyze=True) - ) - logger.debug("The corresponding SQL statement:\n%s", queryset.query) - - return queryset - - def get_min_update_interval(self, queryset): - update_interval = queryset.filter(update_interval__gt=-1 - ).aggregate(Min('update_interval') - ).get('update_interval__min', None) - if update_interval is None: - update_interval = -1 - return update_interval - - def list(self, request, *args, **kwargs): - - validate_search_request = ValidateSearchRequest() - validate_search_request.validate(request) # validate the search request - queryset = self.filter_queryset(self.get_queryset()) - - page = self.paginate_queryset(queryset) - - if page is not None: - serializer = self.get_serializer(page, many=True) - else: - serializer = self.get_serializer(queryset, many=True) - - min_update_interval = None - if request.method in ['GET', 'HEAD', 'OPTIONS']: - if page is None: - queryset_paginated = queryset - else: - queryset_paginated = Item.objects.filter(pk__in=map(lambda item: item.pk, page)) - min_update_interval = self.get_min_update_interval(queryset_paginated) - - data = { - 'type': 'FeatureCollection', - 'timeStamp': utc_aware(datetime.utcnow()), - 'features': serializer.data, - 'links': get_relation_links(request, self.name) - } - - if page is not None: - response = self.paginator.get_paginated_response(data, request) - response = Response(data) - - return response, min_update_interval - - def get(self, request, *args, **kwargs): - response, min_update_interval = self.list(request, *args, **kwargs) - views_mixins.patch_cache_settings_by_update_interval(response, min_update_interval) - return response - - def post(self, request, *args, **kwargs): - response, _ = self.list(request, *args, **kwargs) - return response - - -class CollectionList(generics.GenericAPIView): - name = 'collections-list' # this name must match the name in urls.py - serializer_class = CollectionSerializer - # prefetch_related is a performance optimization to reduce the number - # of DB queries. 
- # see https://docs.djangoproject.com/en/3.1/ref/models/querysets/#prefetch-related - queryset = Collection.objects.filter(published=True).prefetch_related('providers', 'links') - ordering = ['name'] - - def get(self, request, *args, **kwargs): - queryset = self.filter_queryset(self.get_queryset()) - page = self.paginate_queryset(queryset) - if page is not None: - serializer = self.get_serializer(page, many=True) - else: - serializer = self.get_serializer(queryset, many=True) - - data = {'collections': serializer.data, 'links': get_relation_links(request, self.name)} - - if page is not None: - return self.get_paginated_response(data) - return Response(data) - - -@api_view(['POST']) -@permission_classes((permissions.AllowAny,)) -def recalculate_extent(request): - call_calculate_extent() - return Response() - - -class CollectionDetail( - generics.GenericAPIView, - mixins.RetrieveModelMixin, - views_mixins.UpdateInsertModelMixin, - views_mixins.DestroyModelMixin -): - # this name must match the name in urls.py and is used by the DestroyModelMixin - name = 'collection-detail' - serializer_class = CollectionSerializer - lookup_url_kwarg = "collection_name" - lookup_field = "name" - queryset = Collection.objects.all().prefetch_related('providers', 'links') - - @etag(get_collection_etag) - def get(self, request, *args, **kwargs): - return self.retrieve(request, *args, **kwargs) - - # Here the etag is only added to support pre-conditional If-Match and If-Not-Match - @etag(get_collection_etag) - def put(self, request, *args, **kwargs): - return self.upsert(request, *args, **kwargs) - - # Here the etag is only added to support pre-conditional If-Match and If-Not-Match - @etag(get_collection_etag) - def patch(self, request, *args, **kwargs): - return self.partial_update(request, *args, **kwargs) - - # Here the etag is only added to support pre-conditional If-Match and If-Not-Match - @etag(get_collection_etag) - def delete(self, request, *args, **kwargs): - return self.destroy(request, *args, **kwargs) - - def perform_upsert(self, serializer, lookup): - validate_renaming( - serializer, - self.kwargs['collection_name'], - extra_log={ - # pylint: disable=protected-access - 'request': self.request._request, - 'collection': self.kwargs['collection_name'] - } - ) - return super().perform_upsert(serializer, lookup) - - def perform_update(self, serializer, *args, **kwargs): - validate_renaming( - serializer, - self.kwargs['collection_name'], - extra_log={ - # pylint: disable=protected-access - 'request': self.request._request, - 'collection': self.kwargs['collection_name'] - } - ) - return super().perform_update(serializer, *args, **kwargs) - - -class ItemsList(generics.GenericAPIView): - serializer_class = ItemSerializer - ordering = ['name'] - name = 'items-list' # this name must match the name in urls.py - - def get_queryset(self): - # filter based on the url - queryset = Item.objects.filter( - # filter expired items - Q(properties_expires__gte=timezone.now()) | Q(properties_expires=None), - collection__name=self.kwargs['collection_name'] - ).prefetch_related(Prefetch('assets', queryset=Asset.objects.order_by('name')), 'links') - bbox = self.request.query_params.get('bbox', None) - date_time = self.request.query_params.get('datetime', None) - - if bbox: - queryset = queryset.filter_by_bbox(bbox) - - if date_time: - queryset = queryset.filter_by_datetime(date_time) - - if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE: - logger.debug( - "Output of EXPLAIN.. 
ANALYZE from ItemList() view:\n%s", - queryset.explain(verbose=True, analyze=True) - ) - logger.debug("The corresponding SQL statement:\n%s", queryset.query) - - return queryset - - def list(self, request, *args, **kwargs): - validate_collection(self.kwargs) - queryset = self.filter_queryset(self.get_queryset()) - update_interval = Collection.objects.values('update_interval').get( - name=self.kwargs['collection_name'] - )['update_interval'] - page = self.paginate_queryset(queryset) - if page is not None: - serializer = self.get_serializer(page, many=True) - else: - serializer = self.get_serializer(queryset, many=True) - - data = { - 'type': 'FeatureCollection', - 'timeStamp': utc_aware(datetime.utcnow()), - 'features': serializer.data, - 'links': get_relation_links(request, self.name, [self.kwargs['collection_name']]) - } - - if page is not None: - response = self.get_paginated_response(data) - response = Response(data) - views_mixins.patch_cache_settings_by_update_interval(response, update_interval) - return response - - def get(self, request, *args, **kwargs): - return self.list(request, *args, **kwargs) - - -class ItemDetail( - generics.GenericAPIView, - views_mixins.RetrieveModelDynCacheMixin, - views_mixins.UpdateInsertModelMixin, - views_mixins.DestroyModelMixin -): - # this name must match the name in urls.py and is used by the DestroyModelMixin - name = 'item-detail' - serializer_class = ItemSerializer - lookup_url_kwarg = "item_name" - lookup_field = "name" - - def get_queryset(self): - # filter based on the url - queryset = Item.objects.filter( - # filter expired items - Q(properties_expires__gte=timezone.now()) | Q(properties_expires=None), - collection__name=self.kwargs['collection_name'] - ).prefetch_related(Prefetch('assets', queryset=Asset.objects.order_by('name')), 'links') - - if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE: - logger.debug( - "Output of EXPLAIN.. 
ANALYZE from ItemDetail() view:\n%s", - queryset.explain(verbose=True, analyze=True) - ) - logger.debug("The corresponding SQL statement:\n%s", queryset.query) - - return queryset - - def perform_update(self, serializer): - collection = get_object_or_404(Collection, name=self.kwargs['collection_name']) - validate_renaming( - serializer, - self.kwargs['item_name'], - extra_log={ - 'request': self.request._request, # pylint: disable=protected-access - 'collection': self.kwargs['collection_name'], - 'item': self.kwargs['item_name'] - } - ) - serializer.save(collection=collection) - - def perform_upsert(self, serializer, lookup): - collection = get_object_or_404(Collection, name=self.kwargs['collection_name']) - validate_renaming( - serializer, - self.kwargs['item_name'], - extra_log={ - 'request': self.request._request, # pylint: disable=protected-access - 'collection': self.kwargs['collection_name'], - 'item': self.kwargs['item_name'] - } - ) - lookup['collection__name'] = collection.name - return serializer.upsert(lookup, collection=collection) - - @etag(get_item_etag) - def get(self, request, *args, **kwargs): - return self.retrieve(request, *args, **kwargs) - - # Here the etag is only added to support pre-conditional If-Match and If-Not-Match - @etag(get_item_etag) - def put(self, request, *args, **kwargs): - return self.upsert(request, *args, **kwargs) - - # Here the etag is only added to support pre-conditional If-Match and If-Not-Match - @etag(get_item_etag) - def patch(self, request, *args, **kwargs): - return self.partial_update(request, *args, **kwargs) - - # Here the etag is only added to support pre-conditional If-Match and If-Not-Match - @etag(get_item_etag) - def delete(self, request, *args, **kwargs): - return self.destroy(request, *args, **kwargs) - - -class AssetsList(generics.GenericAPIView): - name = 'assets-list' # this name must match the name in urls.py - serializer_class = AssetSerializer - pagination_class = None - - def get_queryset(self): - # filter based on the url - return Asset.objects.filter( - item__collection__name=self.kwargs['collection_name'], - item__name=self.kwargs['item_name'] - ).order_by('name') - - def get(self, request, *args, **kwargs): - validate_item(self.kwargs) - - queryset = self.filter_queryset(self.get_queryset()) - update_interval = Item.objects.values('update_interval').get( - collection__name=self.kwargs['collection_name'], - name=self.kwargs['item_name'], - )['update_interval'] - serializer = self.get_serializer(queryset, many=True) - - data = { - 'assets': serializer.data, - 'links': - get_relation_links( - request, self.name, [self.kwargs['collection_name'], self.kwargs['item_name']] - ) - } - response = Response(data) - views_mixins.patch_cache_settings_by_update_interval(response, update_interval) - return response - - -class AssetDetail( - generics.GenericAPIView, - views_mixins.UpdateInsertModelMixin, - views_mixins.DestroyModelMixin, - views_mixins.RetrieveModelDynCacheMixin -): - # this name must match the name in urls.py and is used by the DestroyModelMixin - name = 'asset-detail' - serializer_class = AssetSerializer - lookup_url_kwarg = "asset_name" - lookup_field = "name" - - def get_queryset(self): - # filter based on the url - return Asset.objects.filter( - Q(item__properties_expires=None) | Q(item__properties_expires__gte=timezone.now()), - item__collection__name=self.kwargs['collection_name'], - item__name=self.kwargs['item_name'] - ) - - def get_serializer(self, *args, **kwargs): - serializer_class = 
self.get_serializer_class() - kwargs.setdefault('context', self.get_serializer_context()) - item = get_object_or_404( - Item, collection__name=self.kwargs['collection_name'], name=self.kwargs['item_name'] - ) - serializer = serializer_class(*args, **kwargs) - - # for the validation the serializer needs to know the collection of the - # item. In case of upserting, the asset doesn't exist and thus the collection - # can't be read from the instance, which is why we pass the collection manually - # here. See serialiers.AssetBaseSerializer._validate_href_field - serializer.collection = item.collection - return serializer - - def _get_file_path(self, serializer, item, asset_name): - """Get the path to the file - - If the collection allows for external asset, and the file is specified - in the request, we set it directly. If the collection doesn't allow it, - error 400. - Otherwise we assemble the path from the file name, collection name as - well as the s3 bucket domain - """ - - if 'file' in serializer.validated_data: - file = serializer.validated_data['file'] - # setting the href makes the asset be external implicitly - is_external = True - else: - file = get_asset_path(item, asset_name) - is_external = False - - return file, is_external - - def perform_update(self, serializer): - item = get_object_or_404( - Item, collection__name=self.kwargs['collection_name'], name=self.kwargs['item_name'] - ) - validate_renaming( - serializer, - original_id=self.kwargs['asset_name'], - extra_log={ - 'request': self.request._request, # pylint: disable=protected-access - 'collection': self.kwargs['collection_name'], - 'item': self.kwargs['item_name'], - 'asset': self.kwargs['asset_name'] - } - ) - file, is_external = self._get_file_path(serializer, item, self.kwargs['asset_name']) - return serializer.save(item=item, file=file, is_external=is_external) - - def perform_upsert(self, serializer, lookup): - item = get_object_or_404( - Item, collection__name=self.kwargs['collection_name'], name=self.kwargs['item_name'] - ) - - validate_renaming( - serializer, - original_id=self.kwargs['asset_name'], - extra_log={ - 'request': self.request._request, # pylint: disable=protected-access - 'collection': self.kwargs['collection_name'], - 'item': self.kwargs['item_name'], - 'asset': self.kwargs['asset_name'] - } - ) - lookup['item__collection__name'] = item.collection.name - lookup['item__name'] = item.name - - file, is_external = self._get_file_path(serializer, item, self.kwargs['asset_name']) - return serializer.upsert(lookup, item=item, file=file, is_external=is_external) - - @etag(get_asset_etag) - def get(self, request, *args, **kwargs): - return self.retrieve(request, *args, **kwargs) - - # Here the etag is only added to support pre-conditional If-Match and If-Not-Match - @etag(get_asset_etag) - def put(self, request, *args, **kwargs): - return self.upsert(request, *args, **kwargs) - - # Here the etag is only added to support pre-conditional If-Match and If-Not-Match - @etag(get_asset_etag) - def patch(self, request, *args, **kwargs): - return self.partial_update(request, *args, **kwargs) - - # Here the etag is only added to support pre-conditional If-Match and If-Not-Match - @etag(get_asset_etag) - def delete(self, request, *args, **kwargs): - return self.destroy(request, *args, **kwargs) - - -class AssetUploadBase(generics.GenericAPIView): - serializer_class = AssetUploadSerializer - lookup_url_kwarg = "upload_id" - lookup_field = "upload_id" - - def get_queryset(self): - return AssetUpload.objects.filter( - 
asset__item__collection__name=self.kwargs['collection_name'], - asset__item__name=self.kwargs['item_name'], - asset__name=self.kwargs['asset_name'] - ).prefetch_related('asset') - - def get_in_progress_queryset(self): - return self.get_queryset().filter(status=AssetUpload.Status.IN_PROGRESS) - - def get_asset_or_404(self): - return get_object_or_404( - Asset.objects.all(), - name=self.kwargs['asset_name'], - item__name=self.kwargs['item_name'], - item__collection__name=self.kwargs['collection_name'] - ) - - def _save_asset_upload(self, executor, serializer, key, asset, upload_id, urls): - try: - with transaction.atomic(): - serializer.save(asset=asset, upload_id=upload_id, urls=urls) - except IntegrityError as error: - logger.error( - 'Failed to create asset upload multipart: %s', - error, - extra={ - 'collection': asset.item.collection.name, - 'item': asset.item.name, - 'asset': asset.name - } - ) - if bool(self.get_in_progress_queryset()): - raise UploadInProgressError( - data={"upload_id": self.get_in_progress_queryset()[0].upload_id} - ) from None - raise - - def create_multipart_upload(self, executor, serializer, validated_data, asset): - key = get_asset_path(asset.item, asset.name) - - upload_id = executor.create_multipart_upload( - key, - asset, - validated_data['checksum_multihash'], - validated_data['update_interval'], - validated_data['content_encoding'] - ) - urls = [] - sorted_md5_parts = sorted(validated_data['md5_parts'], key=itemgetter('part_number')) - - try: - for part in sorted_md5_parts: - urls.append( - executor.create_presigned_url( - key, asset, part['part_number'], upload_id, part['md5'] - ) - ) - - self._save_asset_upload(executor, serializer, key, asset, upload_id, urls) - except APIException as err: - executor.abort_multipart_upload(key, asset, upload_id) - raise - - def complete_multipart_upload(self, executor, validated_data, asset_upload, asset): - key = get_asset_path(asset.item, asset.name) - parts = validated_data.get('parts', None) - if parts is None: - raise serializers.ValidationError({ - 'parts': _("Missing required field") - }, code='missing') - if len(parts) > asset_upload.number_parts: - raise serializers.ValidationError({'parts': [_("Too many parts")]}, code='invalid') - if len(parts) < asset_upload.number_parts: - raise serializers.ValidationError({'parts': [_("Too few parts")]}, code='invalid') - if asset_upload.status != AssetUpload.Status.IN_PROGRESS: - raise UploadNotInProgressError() - executor.complete_multipart_upload(key, asset, parts, asset_upload.upload_id) - asset_upload.update_asset_from_upload() - asset_upload.status = AssetUpload.Status.COMPLETED - asset_upload.ended = utc_aware(datetime.utcnow()) - asset_upload.urls = [] - asset_upload.save() - - def abort_multipart_upload(self, executor, asset_upload, asset): - key = get_asset_path(asset.item, asset.name) - executor.abort_multipart_upload(key, asset, asset_upload.upload_id) - asset_upload.status = AssetUpload.Status.ABORTED - asset_upload.ended = utc_aware(datetime.utcnow()) - asset_upload.urls = [] - asset_upload.save() - - def list_multipart_upload_parts(self, executor, asset_upload, asset, limit, offset): - key = get_asset_path(asset.item, asset.name) - return executor.list_upload_parts(key, asset, asset_upload.upload_id, limit, offset) - - -class AssetUploadsList(AssetUploadBase, mixins.ListModelMixin, views_mixins.CreateModelMixin): - - class ExternalDisallowedException(Exception): - pass - - def post(self, request, *args, **kwargs): - try: - return self.create(request, *args, 
**kwargs) - except self.ExternalDisallowedException as ex: - data = { - "code": 400, - "description": "Not allowed to create multipart uploads on external assets" - } - return Response(status=400, exception=True, data=data) - - def get(self, request, *args, **kwargs): - validate_asset(self.kwargs) - return self.list(request, *args, **kwargs) - - def get_success_headers(self, data): - return {'Location': '/'.join([self.request.build_absolute_uri(), data['upload_id']])} - - def perform_create(self, serializer): - data = serializer.validated_data - asset = self.get_asset_or_404() - collection = asset.item.collection - - if asset.is_external: - raise self.ExternalDisallowedException() - - s3_bucket = select_s3_bucket(collection.name) - executor = MultipartUpload(s3_bucket) - - self.create_multipart_upload(executor, serializer, data, asset) - - def get_queryset(self): - queryset = super().get_queryset() - - status = self.request.query_params.get('status', None) - if status: - queryset = queryset.filter_by_status(status) - - return queryset - - -class AssetUploadDetail(AssetUploadBase, mixins.RetrieveModelMixin, views_mixins.DestroyModelMixin): - - @etag(get_asset_upload_etag) - def get(self, request, *args, **kwargs): - return self.retrieve(request, *args, **kwargs) - - # @etag(get_asset_upload_etag) - # def delete(self, request, *args, **kwargs): - # return self.destroy(request, *args, **kwargs) - - -class AssetUploadComplete(AssetUploadBase, views_mixins.UpdateInsertModelMixin): - - def post(self, request, *args, **kwargs): - kwargs['partial'] = True - return self.update(request, *args, **kwargs) - - def perform_update(self, serializer): - asset = serializer.instance.asset - - collection = asset.item.collection - - s3_bucket = select_s3_bucket(collection.name) - executor = MultipartUpload(s3_bucket) - - self.complete_multipart_upload( - executor, serializer.validated_data, serializer.instance, asset - ) - - -class AssetUploadAbort(AssetUploadBase, views_mixins.UpdateInsertModelMixin): - - def post(self, request, *args, **kwargs): - kwargs['partial'] = True - return self.update(request, *args, **kwargs) - - def perform_update(self, serializer): - asset = serializer.instance.asset - - collection = asset.item.collection - - s3_bucket = select_s3_bucket(collection.name) - executor = MultipartUpload(s3_bucket) - self.abort_multipart_upload(executor, serializer.instance, asset) - - -class AssetUploadPartsList(AssetUploadBase): - serializer_class = AssetUploadPartsSerializer - pagination_class = ExtApiPagination - - def get(self, request, *args, **kwargs): - return self.list(request, *args, **kwargs) - - def list(self, request, *args, **kwargs): - asset_upload = self.get_object() - limit, offset = self.get_pagination_config(request) - - collection = asset_upload.asset.item.collection - s3_bucket = select_s3_bucket(collection.name) - - executor = MultipartUpload(s3_bucket) - - data, has_next = self.list_multipart_upload_parts( - executor, asset_upload, asset_upload.asset, limit, offset - ) - serializer = self.get_serializer(data) - - return self.get_paginated_response(serializer.data, has_next) - - def get_pagination_config(self, request): - return self.paginator.get_pagination_config(request) - - def get_paginated_response(self, data, has_next): # pylint: disable=arguments-differ - return self.paginator.get_paginated_response(data, has_next) diff --git a/app/stac_api/views/__init__.py b/app/stac_api/views/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/app/stac_api/views/collection.py b/app/stac_api/views/collection.py new file mode 100644 index 00000000..1a14d4be --- /dev/null +++ b/app/stac_api/views/collection.py @@ -0,0 +1,258 @@ +import logging + +from django.conf import settings + +from rest_framework import generics +from rest_framework import mixins +from rest_framework.generics import get_object_or_404 +from rest_framework.response import Response +from rest_framework_condition import etag + +from stac_api.models import Collection +from stac_api.models import CollectionAsset +from stac_api.serializers.collection import CollectionAssetSerializer +from stac_api.serializers.collection import CollectionSerializer +from stac_api.serializers.utils import get_relation_links +from stac_api.utils import get_collection_asset_path +from stac_api.validators_view import validate_collection +from stac_api.validators_view import validate_renaming +from stac_api.views.general import get_etag +from stac_api.views.mixins import DestroyModelMixin +from stac_api.views.mixins import RetrieveModelDynCacheMixin +from stac_api.views.mixins import UpdateInsertModelMixin +from stac_api.views.mixins import patch_cache_settings_by_update_interval + +logger = logging.getLogger(__name__) + + +def get_collection_etag(request, *args, **kwargs): + '''Get the ETag for a collection object + + The ETag is an UUID4 computed on each object changes (including relations; provider and links) + ''' + tag = get_etag(Collection.objects.filter(name=kwargs['collection_name'])) + + if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE: + logger.debug( + "Output of EXPLAIN.. ANALYZE from get_collection_etag():\n%s", + Collection.objects.filter(name=kwargs['collection_name'] + ).explain(verbose=True, analyze=True) + ) + logger.debug( + "The corresponding SQL statement:\n%s", + Collection.objects.filter(name=kwargs['collection_name']).query + ) + + return tag + + +def get_collection_asset_etag(request, *args, **kwargs): + '''Get the ETag for a collection asset object + + The ETag is an UUID4 computed on each object changes + ''' + tag = get_etag( + CollectionAsset.objects.filter( + collection__name=kwargs['collection_name'], name=kwargs['asset_name'] + ) + ) + + if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE: + logger.debug( + "Output of EXPLAIN.. ANALYZE from get_collection_asset_etag():\n%s", + CollectionAsset.objects.filter(name=kwargs['asset_name'] + ).explain(verbose=True, analyze=True) + ) + logger.debug( + "The corresponding SQL statement:\n%s", + CollectionAsset.objects.filter(name=kwargs['asset_name']).query + ) + + return tag + + +class CollectionList(generics.GenericAPIView): + name = 'collections-list' # this name must match the name in urls.py + serializer_class = CollectionSerializer + # prefetch_related is a performance optimization to reduce the number + # of DB queries. 
+ # see https://docs.djangoproject.com/en/3.1/ref/models/querysets/#prefetch-related + queryset = Collection.objects.filter(published=True).prefetch_related('providers', 'links') + ordering = ['name'] + + def get(self, request, *args, **kwargs): + queryset = self.filter_queryset(self.get_queryset()) + page = self.paginate_queryset(queryset) + if page is not None: + serializer = self.get_serializer(page, many=True) + else: + serializer = self.get_serializer(queryset, many=True) + + data = {'collections': serializer.data, 'links': get_relation_links(request, self.name)} + + if page is not None: + return self.get_paginated_response(data) + return Response(data) + + +class CollectionDetail( + generics.GenericAPIView, mixins.RetrieveModelMixin, UpdateInsertModelMixin, DestroyModelMixin +): + # this name must match the name in urls.py and is used by the DestroyModelMixin + name = 'collection-detail' + serializer_class = CollectionSerializer + lookup_url_kwarg = "collection_name" + lookup_field = "name" + queryset = Collection.objects.all().prefetch_related('providers', 'links') + + @etag(get_collection_etag) + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_collection_etag) + def put(self, request, *args, **kwargs): + return self.upsert(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_collection_etag) + def patch(self, request, *args, **kwargs): + return self.partial_update(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_collection_etag) + def delete(self, request, *args, **kwargs): + return self.destroy(request, *args, **kwargs) + + def perform_upsert(self, serializer, lookup): + validate_renaming( + serializer, + self.kwargs['collection_name'], + extra_log={ + # pylint: disable=protected-access + 'request': self.request._request, + 'collection': self.kwargs['collection_name'] + } + ) + return super().perform_upsert(serializer, lookup) + + def perform_update(self, serializer, *args, **kwargs): + validate_renaming( + serializer, + self.kwargs['collection_name'], + extra_log={ + # pylint: disable=protected-access + 'request': self.request._request, + 'collection': self.kwargs['collection_name'] + } + ) + return super().perform_update(serializer, *args, **kwargs) + + +class CollectionAssetsList(generics.GenericAPIView): + name = 'collection-assets-list' # this name must match the name in urls.py + serializer_class = CollectionAssetSerializer + pagination_class = None + + def get_queryset(self): + # filter based on the url + return CollectionAsset.objects.filter(collection__name=self.kwargs['collection_name'] + ).order_by('name') + + def get(self, request, *args, **kwargs): + validate_collection(self.kwargs) + + queryset = self.filter_queryset(self.get_queryset()) + update_interval = Collection.objects.values('update_interval').get( + name=self.kwargs['collection_name'] + )['update_interval'] + serializer = self.get_serializer(queryset, many=True) + + data = { + 'assets': serializer.data, + 'links': get_relation_links(request, self.name, [self.kwargs['collection_name']]) + } + response = Response(data) + patch_cache_settings_by_update_interval(response, update_interval) + return response + + +class CollectionAssetDetail( + generics.GenericAPIView, UpdateInsertModelMixin, DestroyModelMixin, 
RetrieveModelDynCacheMixin +): + # this name must match the name in urls.py and is used by the DestroyModelMixin + name = 'collection-asset-detail' + serializer_class = CollectionAssetSerializer + lookup_url_kwarg = "asset_name" + lookup_field = "name" + + def get_queryset(self): + # filter based on the url + return CollectionAsset.objects.filter(collection__name=self.kwargs['collection_name']) + + def get_serializer(self, *args, **kwargs): + serializer_class = self.get_serializer_class() + kwargs.setdefault('context', self.get_serializer_context()) + collection = get_object_or_404(Collection, name=self.kwargs['collection_name']) + serializer = serializer_class(*args, **kwargs) + + # for the validation the serializer needs to know the collection of the + # asset. In case of inserting, the asset doesn't exist and thus the collection + # can't be read from the instance, which is why we pass the collection manually + # here. See serializers.AssetBaseSerializer._validate_href_field + serializer.collection = collection + return serializer + + def perform_update(self, serializer): + collection = get_object_or_404(Collection, name=self.kwargs['collection_name']) + validate_renaming( + serializer, + original_id=self.kwargs['asset_name'], + extra_log={ + 'request': self.request._request, # pylint: disable=protected-access + 'collection': self.kwargs['collection_name'], + 'asset': self.kwargs['asset_name'] + } + ) + return serializer.save( + collection=collection, + file=get_collection_asset_path(collection, self.kwargs['asset_name']) + ) + + def perform_upsert(self, serializer, lookup): + collection = get_object_or_404(Collection, name=self.kwargs['collection_name']) + validate_renaming( + serializer, + original_id=self.kwargs['asset_name'], + extra_log={ + 'request': self.request._request, # pylint: disable=protected-access + 'collection': self.kwargs['collection_name'], + 'asset': self.kwargs['asset_name'] + } + ) + lookup['collection__name'] = collection.name + + return serializer.upsert( + lookup, + collection=collection, + file=get_collection_asset_path(collection, self.kwargs['asset_name']) + ) + + @etag(get_collection_asset_etag) + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_collection_asset_etag) + def put(self, request, *args, **kwargs): + return self.upsert(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_collection_asset_etag) + def patch(self, request, *args, **kwargs): + return self.partial_update(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_collection_asset_etag) + def delete(self, request, *args, **kwargs): + return self.destroy(request, *args, **kwargs) diff --git a/app/stac_api/views/general.py b/app/stac_api/views/general.py new file mode 100644 index 00000000..3b4d3a71 --- /dev/null +++ b/app/stac_api/views/general.py @@ -0,0 +1,161 @@ +import json +import logging +from datetime import datetime + +from django.conf import settings +from django.db.models import Min +from django.utils.translation import gettext_lazy as _ + +from rest_framework import generics +from rest_framework import mixins +from rest_framework import permissions +from rest_framework.decorators import api_view +from rest_framework.decorators import permission_classes +from rest_framework.permissions import AllowAny 
+from rest_framework.response import Response + +from stac_api.models import Item +from stac_api.models import LandingPage +from stac_api.pagination import GetPostCursorPagination +from stac_api.serializers.general import ConformancePageSerializer +from stac_api.serializers.general import LandingPageSerializer +from stac_api.serializers.item import ItemSerializer +from stac_api.serializers.utils import get_relation_links +from stac_api.utils import call_calculate_extent +from stac_api.utils import harmonize_post_get_for_search +from stac_api.utils import is_api_version_1 +from stac_api.utils import utc_aware +from stac_api.validators_serializer import ValidateSearchRequest +from stac_api.views.mixins import patch_cache_settings_by_update_interval + +logger = logging.getLogger(__name__) + + +def get_etag(queryset): + if queryset.exists(): + return list(queryset.only('etag').values('etag').first().values())[0] + return None + + +class LandingPageDetail(generics.RetrieveAPIView): + name = 'landing-page' # this name must match the name in urls.py + serializer_class = LandingPageSerializer + queryset = LandingPage.objects.all() + + def get_object(self): + if not is_api_version_1(self.request): + return LandingPage.objects.get(version='v0.9') + return LandingPage.objects.get(version='v1') + + +class ConformancePageDetail(generics.RetrieveAPIView): + name = 'conformance' # this name must match the name in urls.py + serializer_class = ConformancePageSerializer + queryset = LandingPage.objects.all() + + def get_object(self): + if not is_api_version_1(self.request): + return LandingPage.objects.get(version='v0.9') + return LandingPage.objects.get(version='v1') + + +class SearchList(generics.GenericAPIView, mixins.ListModelMixin): + name = 'search-list' # this name must match the name in urls.py + permission_classes = [AllowAny] + serializer_class = ItemSerializer + pagination_class = GetPostCursorPagination + # It is important to order the result by a unique identifier, because the search endpoint + # search overall collections and that the item name is only unique within a collection + # we must use the pk as ordering attribute, otherwise the cursor pagination will not work + ordering = ['pk'] + + def get_queryset(self): + queryset = Item.objects.filter(collection__published=True + ).prefetch_related('assets', 'links') + # harmonize GET and POST query + query_param = harmonize_post_get_for_search(self.request) + + # build queryset + + # if ids, then the other params will be ignored + if 'ids' in query_param: + queryset = queryset.filter_by_item_name(query_param['ids']) + else: + if 'bbox' in query_param: + queryset = queryset.filter_by_bbox(query_param['bbox']) + if 'datetime' in query_param: + queryset = queryset.filter_by_datetime(query_param['datetime']) + if 'collections' in query_param: + queryset = queryset.filter_by_collections(query_param['collections']) + if 'query' in query_param: + dict_query = json.loads(query_param['query']) + queryset = queryset.filter_by_query(dict_query) + if 'intersects' in query_param: + queryset = queryset.filter_by_intersects(json.dumps(query_param['intersects'])) + + if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE: + logger.debug( + "Output of EXPLAIN.. 
ANALYZE from SearchList() view:\n%s", + queryset.explain(verbose=True, analyze=True) + ) + logger.debug("The corresponding SQL statement:\n%s", queryset.query) + + return queryset + + def get_min_update_interval(self, queryset): + update_interval = queryset.filter(update_interval__gt=-1 + ).aggregate(Min('update_interval') + ).get('update_interval__min', None) + if update_interval is None: + update_interval = -1 + return update_interval + + def list(self, request, *args, **kwargs): + + validate_search_request = ValidateSearchRequest() + validate_search_request.validate(request) # validate the search request + queryset = self.filter_queryset(self.get_queryset()) + + page = self.paginate_queryset(queryset) + + if page is not None: + serializer = self.get_serializer(page, many=True) + else: + serializer = self.get_serializer(queryset, many=True) + + min_update_interval = None + if request.method in ['GET', 'HEAD', 'OPTIONS']: + if page is None: + queryset_paginated = queryset + else: + queryset_paginated = Item.objects.filter(pk__in=map(lambda item: item.pk, page)) + min_update_interval = self.get_min_update_interval(queryset_paginated) + + data = { + 'type': 'FeatureCollection', + 'timeStamp': utc_aware(datetime.utcnow()), + 'features': serializer.data, + 'links': get_relation_links(request, self.name) + } + + if page is not None: + response = self.paginator.get_paginated_response(data, request) + else: + response = Response(data) + + return response, min_update_interval + + def get(self, request, *args, **kwargs): + response, min_update_interval = self.list(request, *args, **kwargs) + patch_cache_settings_by_update_interval(response, min_update_interval) + return response + + def post(self, request, *args, **kwargs): + response, _ = self.list(request, *args, **kwargs) + return response + + +@api_view(['POST']) +@permission_classes((permissions.AllowAny,)) +def recalculate_extent(request): + call_calculate_extent() + return Response() diff --git a/app/stac_api/views/item.py b/app/stac_api/views/item.py new file mode 100644 index 00000000..4efa73cf --- /dev/null +++ b/app/stac_api/views/item.py @@ -0,0 +1,363 @@ +import logging +from datetime import datetime + +from django.conf import settings +from django.db.models import Prefetch +from django.db.models import Q +from django.utils import timezone + +from rest_framework import generics +from rest_framework.generics import get_object_or_404 +from rest_framework.response import Response +from rest_framework_condition import etag + +from stac_api.models import Asset +from stac_api.models import Collection +from stac_api.models import Item +from stac_api.serializers.item import AssetSerializer +from stac_api.serializers.item import ItemSerializer +from stac_api.serializers.utils import get_relation_links +from stac_api.utils import get_asset_path +from stac_api.utils import utc_aware +from stac_api.validators_view import validate_collection +from stac_api.validators_view import validate_item +from stac_api.validators_view import validate_renaming +from stac_api.views import mixins +from stac_api.views.general import get_etag + +logger = logging.getLogger(__name__) + + +def get_item_etag(request, *args, **kwargs): + '''Get the ETag for a item object + + The ETag is an UUID4 computed on each object changes (including relations; assets and links) + ''' + tag = get_etag( + Item.objects.filter(collection__name=kwargs['collection_name'], name=kwargs['item_name']) + ) + + if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE: + logger.debug( + "Output of EXPLAIN..
ANALYZE from get_item_etag():\n%s", + Item.objects.filter( + collection__name=kwargs['collection_name'], name=kwargs['item_name'] + ).explain(verbose=True, analyze=True) + ) + logger.debug( + "The corresponding SQL statement:\n%s", + Item.objects.filter( + collection__name=kwargs['collection_name'], name=kwargs['item_name'] + ).query + ) + + return tag + + +def get_asset_etag(request, *args, **kwargs): + '''Get the ETag for a asset object + + The ETag is an UUID4 computed on each object changes + ''' + tag = get_etag( + Asset.objects.filter( + item__collection__name=kwargs['collection_name'], + item__name=kwargs['item_name'], + name=kwargs['asset_name'] + ) + ) + + if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE: + logger.debug( + "Output of EXPLAIN.. ANALYZE from get_asset_etag():\n%s", + Asset.objects.filter(item__name=kwargs['item_name'], + name=kwargs['asset_name']).explain(verbose=True, analyze=True) + ) + logger.debug( + "The corresponding SQL statement:\n%s", + Asset.objects.filter(item__name=kwargs['item_name'], name=kwargs['asset_name']).query + ) + + return tag + + +class ItemsList(generics.GenericAPIView): + serializer_class = ItemSerializer + ordering = ['name'] + name = 'items-list' # this name must match the name in urls.py + + def get_queryset(self): + # filter based on the url + queryset = Item.objects.filter( + # filter expired items + Q(properties_expires__gte=timezone.now()) | Q(properties_expires=None), + collection__name=self.kwargs['collection_name'] + ).prefetch_related(Prefetch('assets', queryset=Asset.objects.order_by('name')), 'links') + bbox = self.request.query_params.get('bbox', None) + date_time = self.request.query_params.get('datetime', None) + + if bbox: + queryset = queryset.filter_by_bbox(bbox) + + if date_time: + queryset = queryset.filter_by_datetime(date_time) + + if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE: + logger.debug( + "Output of EXPLAIN.. 
ANALYZE from ItemList() view:\n%s", + queryset.explain(verbose=True, analyze=True) + ) + logger.debug("The corresponding SQL statement:\n%s", queryset.query) + + return queryset + + def list(self, request, *args, **kwargs): + validate_collection(self.kwargs) + queryset = self.filter_queryset(self.get_queryset()) + update_interval = Collection.objects.values('update_interval').get( + name=self.kwargs['collection_name'] + )['update_interval'] + page = self.paginate_queryset(queryset) + if page is not None: + serializer = self.get_serializer(page, many=True) + else: + serializer = self.get_serializer(queryset, many=True) + + data = { + 'type': 'FeatureCollection', + 'timeStamp': utc_aware(datetime.utcnow()), + 'features': serializer.data, + 'links': get_relation_links(request, self.name, [self.kwargs['collection_name']]) + } + + if page is not None: + response = self.get_paginated_response(data) + else: + response = Response(data) + mixins.patch_cache_settings_by_update_interval(response, update_interval) + return response + + def get(self, request, *args, **kwargs): + return self.list(request, *args, **kwargs) + + +class ItemDetail( + generics.GenericAPIView, + mixins.RetrieveModelDynCacheMixin, + mixins.UpdateInsertModelMixin, + mixins.DestroyModelMixin +): + # this name must match the name in urls.py and is used by the DestroyModelMixin + name = 'item-detail' + serializer_class = ItemSerializer + lookup_url_kwarg = "item_name" + lookup_field = "name" + + def get_queryset(self): + # filter based on the url + queryset = Item.objects.filter( + # filter expired items + Q(properties_expires__gte=timezone.now()) | Q(properties_expires=None), + collection__name=self.kwargs['collection_name'] + ).prefetch_related(Prefetch('assets', queryset=Asset.objects.order_by('name')), 'links') + + if settings.DEBUG_ENABLE_DB_EXPLAIN_ANALYZE: + logger.debug( + "Output of EXPLAIN..
ANALYZE from ItemDetail() view:\n%s", + queryset.explain(verbose=True, analyze=True) + ) + logger.debug("The corresponding SQL statement:\n%s", queryset.query) + + return queryset + + def perform_update(self, serializer): + collection = get_object_or_404(Collection, name=self.kwargs['collection_name']) + validate_renaming( + serializer, + self.kwargs['item_name'], + extra_log={ + 'request': self.request._request, # pylint: disable=protected-access + 'collection': self.kwargs['collection_name'], + 'item': self.kwargs['item_name'] + } + ) + serializer.save(collection=collection) + + def perform_upsert(self, serializer, lookup): + collection = get_object_or_404(Collection, name=self.kwargs['collection_name']) + validate_renaming( + serializer, + self.kwargs['item_name'], + extra_log={ + 'request': self.request._request, # pylint: disable=protected-access + 'collection': self.kwargs['collection_name'], + 'item': self.kwargs['item_name'] + } + ) + lookup['collection__name'] = collection.name + return serializer.upsert(lookup, collection=collection) + + @etag(get_item_etag) + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_item_etag) + def put(self, request, *args, **kwargs): + return self.upsert(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_item_etag) + def patch(self, request, *args, **kwargs): + return self.partial_update(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_item_etag) + def delete(self, request, *args, **kwargs): + return self.destroy(request, *args, **kwargs) + + +class AssetsList(generics.GenericAPIView): + name = 'assets-list' # this name must match the name in urls.py + serializer_class = AssetSerializer + pagination_class = None + + def get_queryset(self): + # filter based on the url + return Asset.objects.filter( + item__collection__name=self.kwargs['collection_name'], + item__name=self.kwargs['item_name'] + ).order_by('name') + + def get(self, request, *args, **kwargs): + validate_item(self.kwargs) + + queryset = self.filter_queryset(self.get_queryset()) + update_interval = Item.objects.values('update_interval').get( + collection__name=self.kwargs['collection_name'], + name=self.kwargs['item_name'], + )['update_interval'] + serializer = self.get_serializer(queryset, many=True) + + data = { + 'assets': serializer.data, + 'links': + get_relation_links( + request, self.name, [self.kwargs['collection_name'], self.kwargs['item_name']] + ) + } + response = Response(data) + mixins.patch_cache_settings_by_update_interval(response, update_interval) + return response + + +class AssetDetail( + generics.GenericAPIView, + mixins.UpdateInsertModelMixin, + mixins.DestroyModelMixin, + mixins.RetrieveModelDynCacheMixin +): + # this name must match the name in urls.py and is used by the DestroyModelMixin + name = 'asset-detail' + serializer_class = AssetSerializer + lookup_url_kwarg = "asset_name" + lookup_field = "name" + + def get_queryset(self): + # filter based on the url + return Asset.objects.filter( + Q(item__properties_expires=None) | Q(item__properties_expires__gte=timezone.now()), + item__collection__name=self.kwargs['collection_name'], + item__name=self.kwargs['item_name'] + ) + + def get_serializer(self, *args, **kwargs): + serializer_class = self.get_serializer_class() + 
kwargs.setdefault('context', self.get_serializer_context()) + item = get_object_or_404( + Item, collection__name=self.kwargs['collection_name'], name=self.kwargs['item_name'] + ) + serializer = serializer_class(*args, **kwargs) + + # for the validation the serializer needs to know the collection of the + # item. In case of upserting, the asset doesn't exist and thus the collection + # can't be read from the instance, which is why we pass the collection manually + # here. See serialiers.AssetBaseSerializer._validate_href_field + serializer.collection = item.collection + return serializer + + def _get_file_path(self, serializer, item, asset_name): + """Get the path to the file + + If the collection allows for external asset, and the file is specified + in the request, we set it directly. If the collection doesn't allow it, + error 400. + Otherwise we assemble the path from the file name, collection name as + well as the s3 bucket domain + """ + + if 'file' in serializer.validated_data: + file = serializer.validated_data['file'] + # setting the href makes the asset be external implicitly + is_external = True + else: + file = get_asset_path(item, asset_name) + is_external = False + + return file, is_external + + def perform_update(self, serializer): + item = get_object_or_404( + Item, collection__name=self.kwargs['collection_name'], name=self.kwargs['item_name'] + ) + validate_renaming( + serializer, + original_id=self.kwargs['asset_name'], + extra_log={ + 'request': self.request._request, # pylint: disable=protected-access + 'collection': self.kwargs['collection_name'], + 'item': self.kwargs['item_name'], + 'asset': self.kwargs['asset_name'] + } + ) + file, is_external = self._get_file_path(serializer, item, self.kwargs['asset_name']) + return serializer.save(item=item, file=file, is_external=is_external) + + def perform_upsert(self, serializer, lookup): + item = get_object_or_404( + Item, collection__name=self.kwargs['collection_name'], name=self.kwargs['item_name'] + ) + + validate_renaming( + serializer, + original_id=self.kwargs['asset_name'], + extra_log={ + 'request': self.request._request, # pylint: disable=protected-access + 'collection': self.kwargs['collection_name'], + 'item': self.kwargs['item_name'], + 'asset': self.kwargs['asset_name'] + } + ) + lookup['item__collection__name'] = item.collection.name + lookup['item__name'] = item.name + + file, is_external = self._get_file_path(serializer, item, self.kwargs['asset_name']) + return serializer.upsert(lookup, item=item, file=file, is_external=is_external) + + @etag(get_asset_etag) + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_asset_etag) + def put(self, request, *args, **kwargs): + return self.upsert(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_asset_etag) + def patch(self, request, *args, **kwargs): + return self.partial_update(request, *args, **kwargs) + + # Here the etag is only added to support pre-conditional If-Match and If-Not-Match + @etag(get_asset_etag) + def delete(self, request, *args, **kwargs): + return self.destroy(request, *args, **kwargs) diff --git a/app/stac_api/views_mixins.py b/app/stac_api/views/mixins.py similarity index 99% rename from app/stac_api/views_mixins.py rename to app/stac_api/views/mixins.py index 2c63c70f..8e8f1569 100644 --- a/app/stac_api/views_mixins.py +++ 
b/app/stac_api/views/mixins.py @@ -11,7 +11,7 @@ from rest_framework import status from rest_framework.response import Response -from stac_api.serializers_utils import get_parent_link +from stac_api.serializers.utils import get_parent_link from stac_api.utils import get_dynamic_max_age_value from stac_api.utils import get_link diff --git a/app/stac_api/views_test.py b/app/stac_api/views/test.py similarity index 69% rename from app/stac_api/views_test.py rename to app/stac_api/views/test.py index a23f9e3e..85b7d1ab 100644 --- a/app/stac_api/views_test.py +++ b/app/stac_api/views/test.py @@ -3,9 +3,10 @@ from rest_framework import generics from stac_api.models import LandingPage -from stac_api.views import AssetDetail -from stac_api.views import CollectionDetail -from stac_api.views import ItemDetail +from stac_api.views.collection import CollectionAssetDetail +from stac_api.views.collection import CollectionDetail +from stac_api.views.item import AssetDetail +from stac_api.views.item import ItemDetail logger = logging.getLogger(__name__) @@ -39,3 +40,10 @@ class TestAssetUpsertHttp500(AssetDetail): def perform_upsert(self, serializer, lookup): super().perform_upsert(serializer, lookup) raise AttributeError('test exception') + + +class TestCollectionAssetUpsertHttp500(CollectionAssetDetail): + + def perform_upsert(self, serializer, lookup): + super().perform_upsert(serializer, lookup) + raise AttributeError('test exception') diff --git a/app/stac_api/views/upload.py b/app/stac_api/views/upload.py new file mode 100644 index 00000000..7927b6fc --- /dev/null +++ b/app/stac_api/views/upload.py @@ -0,0 +1,447 @@ +import logging +from datetime import datetime +from operator import itemgetter + +from django.db import IntegrityError +from django.db import transaction +from django.utils.translation import gettext_lazy as _ + +from rest_framework import generics +from rest_framework import mixins +from rest_framework import serializers +from rest_framework.exceptions import APIException +from rest_framework.generics import get_object_or_404 +from rest_framework.response import Response +from rest_framework_condition import etag + +from stac_api.exceptions import UploadInProgressError +from stac_api.exceptions import UploadNotInProgressError +from stac_api.models import Asset +from stac_api.models import AssetUpload +from stac_api.models import BaseAssetUpload +from stac_api.models import CollectionAsset +from stac_api.models import CollectionAssetUpload +from stac_api.pagination import ExtApiPagination +from stac_api.s3_multipart_upload import MultipartUpload +from stac_api.serializers.upload import AssetUploadPartsSerializer +from stac_api.serializers.upload import AssetUploadSerializer +from stac_api.serializers.upload import CollectionAssetUploadSerializer +from stac_api.utils import get_asset_path +from stac_api.utils import get_collection_asset_path +from stac_api.utils import select_s3_bucket +from stac_api.utils import utc_aware +from stac_api.validators_view import validate_asset +from stac_api.validators_view import validate_collection_asset +from stac_api.views.general import get_etag +from stac_api.views.mixins import CreateModelMixin +from stac_api.views.mixins import DestroyModelMixin +from stac_api.views.mixins import UpdateInsertModelMixin + +logger = logging.getLogger(__name__) + + +def get_asset_upload_etag(request, *args, **kwargs): + '''Get the ETag for an asset upload object + + The ETag is an UUID4 computed on each object changes + ''' + return get_etag( + 
AssetUpload.objects.filter( + asset__item__collection__name=kwargs['collection_name'], + asset__item__name=kwargs['item_name'], + asset__name=kwargs['asset_name'], + upload_id=kwargs['upload_id'] + ) + ) + + +def get_collection_asset_upload_etag(request, *args, **kwargs): + '''Get the ETag for a collection asset upload object + + The ETag is an UUID4 computed on each object changes + ''' + return get_etag( + CollectionAssetUpload.objects.filter( + asset__collection__name=kwargs['collection_name'], + asset__name=kwargs['asset_name'], + upload_id=kwargs['upload_id'] + ) + ) + + +class SharedAssetUploadBase(generics.GenericAPIView): + """SharedAssetUploadBase provides a base view for asset uploads and collection asset uploads. + """ + lookup_url_kwarg = "upload_id" + lookup_field = "upload_id" + + def get_queryset(self): + raise NotImplementedError("get_queryset() not implemented") + + def get_in_progress_queryset(self): + return self.get_queryset().filter(status=BaseAssetUpload.Status.IN_PROGRESS) + + def get_asset_or_404(self): + raise NotImplementedError("get_asset_or_404() not implemented") + + def log_extra(self, asset): + if isinstance(asset, CollectionAsset): + return {'collection': asset.collection.name, 'asset': asset.name} + return { + 'collection': asset.item.collection.name, 'item': asset.item.name, 'asset': asset.name + } + + def get_path(self, asset): + if isinstance(asset, CollectionAsset): + return get_collection_asset_path(asset.collection, asset.name) + return get_asset_path(asset.item, asset.name) + + def _save_asset_upload(self, executor, serializer, key, asset, upload_id, urls): + try: + with transaction.atomic(): + serializer.save(asset=asset, upload_id=upload_id, urls=urls) + except IntegrityError as error: + logger.error( + 'Failed to create asset upload multipart: %s', error, extra=self.log_extra(asset) + ) + if bool(self.get_in_progress_queryset()): + raise UploadInProgressError( + data={"upload_id": self.get_in_progress_queryset()[0].upload_id} + ) from None + raise + + def create_multipart_upload(self, executor, serializer, validated_data, asset): + key = self.get_path(asset) + + upload_id = executor.create_multipart_upload( + key, + asset, + validated_data['checksum_multihash'], + validated_data['update_interval'], + validated_data['content_encoding'] + ) + urls = [] + sorted_md5_parts = sorted(validated_data['md5_parts'], key=itemgetter('part_number')) + + try: + for part in sorted_md5_parts: + urls.append( + executor.create_presigned_url( + key, asset, part['part_number'], upload_id, part['md5'] + ) + ) + + self._save_asset_upload(executor, serializer, key, asset, upload_id, urls) + except APIException as err: + executor.abort_multipart_upload(key, asset, upload_id) + raise + + def complete_multipart_upload(self, executor, validated_data, asset_upload, asset): + key = self.get_path(asset) + parts = validated_data.get('parts', None) + if parts is None: + raise serializers.ValidationError({ + 'parts': _("Missing required field") + }, code='missing') + if len(parts) > asset_upload.number_parts: + raise serializers.ValidationError({'parts': [_("Too many parts")]}, code='invalid') + if len(parts) < asset_upload.number_parts: + raise serializers.ValidationError({'parts': [_("Too few parts")]}, code='invalid') + if asset_upload.status != BaseAssetUpload.Status.IN_PROGRESS: + raise UploadNotInProgressError() + executor.complete_multipart_upload(key, asset, parts, asset_upload.upload_id) + asset_upload.update_asset_from_upload() + asset_upload.status = 
BaseAssetUpload.Status.COMPLETED + asset_upload.ended = utc_aware(datetime.utcnow()) + asset_upload.urls = [] + asset_upload.save() + + def abort_multipart_upload(self, executor, asset_upload, asset): + key = self.get_path(asset) + executor.abort_multipart_upload(key, asset, asset_upload.upload_id) + asset_upload.status = BaseAssetUpload.Status.ABORTED + asset_upload.ended = utc_aware(datetime.utcnow()) + asset_upload.urls = [] + asset_upload.save() + + def list_multipart_upload_parts(self, executor, asset_upload, asset, limit, offset): + key = self.get_path(asset) + return executor.list_upload_parts(key, asset, asset_upload.upload_id, limit, offset) + + +class AssetUploadBase(SharedAssetUploadBase): + """AssetUploadBase is the base for all asset (not collection asset) upload views. + """ + serializer_class = AssetUploadSerializer + + def get_queryset(self): + return AssetUpload.objects.filter( + asset__item__collection__name=self.kwargs['collection_name'], + asset__item__name=self.kwargs['item_name'], + asset__name=self.kwargs['asset_name'] + ).prefetch_related('asset') + + def get_asset_or_404(self): + return get_object_or_404( + Asset.objects.all(), + name=self.kwargs['asset_name'], + item__name=self.kwargs['item_name'], + item__collection__name=self.kwargs['collection_name'] + ) + + +class AssetUploadsList(AssetUploadBase, mixins.ListModelMixin, CreateModelMixin): + + class ExternalDisallowedException(Exception): + pass + + def post(self, request, *args, **kwargs): + try: + return self.create(request, *args, **kwargs) + except self.ExternalDisallowedException as ex: + data = { + "code": 400, + "description": "Not allowed to create multipart uploads on external assets" + } + return Response(status=400, exception=True, data=data) + + def get(self, request, *args, **kwargs): + validate_asset(self.kwargs) + return self.list(request, *args, **kwargs) + + def get_success_headers(self, data): + return {'Location': '/'.join([self.request.build_absolute_uri(), data['upload_id']])} + + def perform_create(self, serializer): + data = serializer.validated_data + asset = self.get_asset_or_404() + collection = asset.item.collection + + if asset.is_external: + raise self.ExternalDisallowedException() + + s3_bucket = select_s3_bucket(collection.name) + executor = MultipartUpload(s3_bucket) + + self.create_multipart_upload(executor, serializer, data, asset) + + def get_queryset(self): + queryset = super().get_queryset() + + status = self.request.query_params.get('status', None) + if status: + queryset = queryset.filter_by_status(status) + + return queryset + + +class AssetUploadDetail(AssetUploadBase, mixins.RetrieveModelMixin, DestroyModelMixin): + + @etag(get_asset_upload_etag) + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + +class AssetUploadComplete(AssetUploadBase, UpdateInsertModelMixin): + + def post(self, request, *args, **kwargs): + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + + def perform_update(self, serializer): + asset = serializer.instance.asset + + collection = asset.item.collection + + s3_bucket = select_s3_bucket(collection.name) + executor = MultipartUpload(s3_bucket) + + self.complete_multipart_upload( + executor, serializer.validated_data, serializer.instance, asset + ) + + +class AssetUploadAbort(AssetUploadBase, UpdateInsertModelMixin): + + def post(self, request, *args, **kwargs): + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + + def perform_update(self, serializer): + asset = 
serializer.instance.asset + + collection = asset.item.collection + + s3_bucket = select_s3_bucket(collection.name) + executor = MultipartUpload(s3_bucket) + self.abort_multipart_upload(executor, serializer.instance, asset) + + +class AssetUploadPartsList(AssetUploadBase): + serializer_class = AssetUploadPartsSerializer + pagination_class = ExtApiPagination + + def get(self, request, *args, **kwargs): + return self.list(request, *args, **kwargs) + + def list(self, request, *args, **kwargs): + asset_upload = self.get_object() + limit, offset = self.get_pagination_config(request) + + collection = asset_upload.asset.item.collection + s3_bucket = select_s3_bucket(collection.name) + + executor = MultipartUpload(s3_bucket) + + data, has_next = self.list_multipart_upload_parts( + executor, asset_upload, asset_upload.asset, limit, offset + ) + serializer = self.get_serializer(data) + + return self.get_paginated_response(serializer.data, has_next) + + def get_pagination_config(self, request): + return self.paginator.get_pagination_config(request) + + def get_paginated_response(self, data, has_next): # pylint: disable=arguments-differ + return self.paginator.get_paginated_response(data, has_next) + + +class CollectionAssetUploadBase(SharedAssetUploadBase): + """CollectionAssetUploadBase is the base for all collection asset upload views. + """ + serializer_class = CollectionAssetUploadSerializer + + def get_queryset(self): + return CollectionAssetUpload.objects.filter( + asset__collection__name=self.kwargs['collection_name'], + asset__name=self.kwargs['asset_name'] + ).prefetch_related('asset') + + def get_asset_or_404(self): + return get_object_or_404( + CollectionAsset.objects.all(), + name=self.kwargs['asset_name'], + collection__name=self.kwargs['collection_name'] + ) + + +class CollectionAssetUploadsList( + CollectionAssetUploadBase, mixins.ListModelMixin, CreateModelMixin +): + + class ExternalDisallowedException(Exception): + pass + + def post(self, request, *args, **kwargs): + try: + return self.create(request, *args, **kwargs) + except self.ExternalDisallowedException as ex: + data = { + "code": 400, + "description": "Not allowed to create multipart uploads on external assets" + } + return Response(status=400, exception=True, data=data) + + def get(self, request, *args, **kwargs): + validate_collection_asset(self.kwargs) + return self.list(request, *args, **kwargs) + + def get_success_headers(self, data): + return {'Location': '/'.join([self.request.build_absolute_uri(), data['upload_id']])} + + def perform_create(self, serializer): + data = serializer.validated_data + asset = self.get_asset_or_404() + collection = asset.collection + + if asset.is_external: + raise self.ExternalDisallowedException() + + s3_bucket = select_s3_bucket(collection.name) + executor = MultipartUpload(s3_bucket) + + self.create_multipart_upload(executor, serializer, data, asset) + + def get_queryset(self): + queryset = super().get_queryset() + + status = self.request.query_params.get('status', None) + if status: + queryset = queryset.filter_by_status(status) + + return queryset + + +class CollectionAssetUploadDetail( + CollectionAssetUploadBase, mixins.RetrieveModelMixin, DestroyModelMixin +): + + @etag(get_collection_asset_upload_etag) + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + +class CollectionAssetUploadComplete(CollectionAssetUploadBase, UpdateInsertModelMixin): + + def post(self, request, *args, **kwargs): + kwargs['partial'] = True + return 
self.update(request, *args, **kwargs) + + def perform_update(self, serializer): + asset = serializer.instance.asset + + collection = asset.collection + + s3_bucket = select_s3_bucket(collection.name) + executor = MultipartUpload(s3_bucket) + + self.complete_multipart_upload( + executor, serializer.validated_data, serializer.instance, asset + ) + + +class CollectionAssetUploadAbort(CollectionAssetUploadBase, UpdateInsertModelMixin): + + def post(self, request, *args, **kwargs): + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + + def perform_update(self, serializer): + asset = serializer.instance.asset + + collection = asset.collection + + s3_bucket = select_s3_bucket(collection.name) + executor = MultipartUpload(s3_bucket) + self.abort_multipart_upload(executor, serializer.instance, asset) + + +class CollectionAssetUploadPartsList(CollectionAssetUploadBase): + serializer_class = AssetUploadPartsSerializer + pagination_class = ExtApiPagination + + def get(self, request, *args, **kwargs): + return self.list(request, *args, **kwargs) + + def list(self, request, *args, **kwargs): + asset_upload = self.get_object() + limit, offset = self.get_pagination_config(request) + + collection = asset_upload.asset.collection + s3_bucket = select_s3_bucket(collection.name) + + executor = MultipartUpload(s3_bucket) + + data, has_next = self.list_multipart_upload_parts( + executor, asset_upload, asset_upload.asset, limit, offset + ) + serializer = self.get_serializer(data) + + return self.get_paginated_response(serializer.data, has_next) + + def get_pagination_config(self, request): + return self.paginator.get_pagination_config(request) + + def get_paginated_response(self, data, has_next): # pylint: disable=arguments-differ + return self.paginator.get_paginated_response(data, has_next) diff --git a/app/tests/tests_09/test_serializer.py b/app/tests/tests_09/test_serializer.py index bd9a40c2..66374304 100644 --- a/app/tests/tests_09/test_serializer.py +++ b/app/tests/tests_09/test_serializer.py @@ -13,9 +13,9 @@ from rest_framework.test import APIRequestFactory from stac_api.models import get_asset_path -from stac_api.serializers import AssetSerializer -from stac_api.serializers import CollectionSerializer -from stac_api.serializers import ItemSerializer +from stac_api.serializers.collection import CollectionSerializer +from stac_api.serializers.item import AssetSerializer +from stac_api.serializers.item import ItemSerializer from stac_api.utils import get_link from stac_api.utils import isoformat from stac_api.utils import utc_aware diff --git a/app/tests/tests_09/test_serializer_asset_upload.py b/app/tests/tests_09/test_serializer_asset_upload.py index 7d0e8ce7..e5a54e43 100644 --- a/app/tests/tests_09/test_serializer_asset_upload.py +++ b/app/tests/tests_09/test_serializer_asset_upload.py @@ -7,7 +7,7 @@ from rest_framework.exceptions import ValidationError from stac_api.models import AssetUpload -from stac_api.serializers import AssetUploadSerializer +from stac_api.serializers.upload import AssetUploadSerializer from stac_api.utils import get_sha256_multihash from tests.tests_09.base_test import STAC_BASE_V diff --git a/app/tests/tests_10/base_test.py b/app/tests/tests_10/base_test.py index bbfc86ba..cf02f2f8 100644 --- a/app/tests/tests_10/base_test.py +++ b/app/tests/tests_10/base_test.py @@ -302,6 +302,56 @@ def check_stac_asset(self, expected, current, collection, item, ignore=None): ] self._check_stac_links('asset.links', links, current['links']) + def 
check_stac_collection_asset(self, expected, current, collection, ignore=None): + '''Check a STAC Collection Asset data + + Check if the `current` collection asset data match the `expected`. This check is a subset + check which means that if a value is missing from `current`, then it raises a Test Assert, + while if a value is in `current` but not in `expected`, the test passed. The functions + knows also the STAC Spec and does some check based on it. + + Args: + expected: dict + Expected STAC Asset + current: dict + Current STAC Asset to test + ignore: list(string) | None + List of keys to ignore in the test + ''' + if ignore is None: + ignore = [] + self._check_stac_dictsubset('asset', expected, current, ignore=ignore) + + # check required fields + for key in ['links', 'id', 'type', 'href']: + if key in ignore: + logger.info('Ignoring key %s in asset', key) + continue + self.assertIn(key, current, msg=f'Asset {key} is missing') + for date_field in ['created', 'updated']: + if key in ignore: + logger.info('Ignoring key %s in asset', key) + continue + self.assertIn(date_field, current, msg=f'Asset {date_field} is missing') + self.assertTrue( + fromisoformat(current[date_field]), + msg=f"The asset field {date_field} has an invalid date" + ) + if 'links' not in ignore: + name = current['id'] + links = [ + { + 'rel': 'self', + 'href': f'{TEST_LINK_ROOT_HREF}/collections/{collection}/assets/{name}' + }, + TEST_LINK_ROOT, + { + 'rel': 'parent', + 'href': f'{TEST_LINK_ROOT_HREF}/collections/{collection}', + }, + ] + self._check_stac_links('asset.links', links, current['links']) + def _check_stac_dictsubset(self, parent_path, expected, current, ignore=None): for key, value in expected.items(): path = f'{parent_path}.{key}' diff --git a/app/tests/tests_10/data_factory.py b/app/tests/tests_10/data_factory.py index 3a463fc1..2eebaa92 100644 --- a/app/tests/tests_10/data_factory.py +++ b/app/tests/tests_10/data_factory.py @@ -801,23 +801,12 @@ class CollectionAssetSample(SampleData): samples_dict = collection_asset_samples key_mapping = { 'name': 'id', - 'eo_gsd': 'gsd', - 'geoadmin_variant': 'geoadmin:variant', - 'geoadmin_lang': 'geoadmin:lang', 'proj_epsg': 'proj:epsg', 'media_type': 'type', 'checksum_multihash': 'file:checksum', 'file': 'href' } - optional_fields = [ - 'title', - 'description', - 'eo_gsd', - 'geoadmin_variant', - 'geoadmin_lang', - 'proj_epsg', - 'checksum_multihash' - ] + optional_fields = ['title', 'description', 'proj_epsg', 'checksum_multihash'] read_only_fields = ['created', 'updated', 'href', 'file:checksum'] def __init__(self, collection, sample='asset-1', name=None, required_only=False, **kwargs): diff --git a/app/tests/tests_10/test_collection_asset_upload_endpoint.py b/app/tests/tests_10/test_collection_asset_upload_endpoint.py new file mode 100644 index 00000000..4945b1a2 --- /dev/null +++ b/app/tests/tests_10/test_collection_asset_upload_endpoint.py @@ -0,0 +1,1266 @@ +# pylint: disable=too-many-ancestors,too-many-lines +import gzip +import hashlib +import logging +from base64 import b64encode +from datetime import datetime +from urllib import parse + +from django.conf import settings +from django.contrib.auth import get_user_model +from django.test import Client + +from stac_api.models import CollectionAsset +from stac_api.models import CollectionAssetUpload +from stac_api.utils import fromisoformat +from stac_api.utils import get_collection_asset_path +from stac_api.utils import get_s3_client +from stac_api.utils import get_sha256_multihash +from stac_api.utils 
import utc_aware + +from tests.tests_10.base_test import StacBaseTestCase +from tests.tests_10.base_test import StacBaseTransactionTestCase +from tests.tests_10.data_factory import Factory +from tests.tests_10.utils import reverse_version +from tests.utils import S3TestMixin +from tests.utils import client_login +from tests.utils import get_file_like_object +from tests.utils import mock_s3_asset_file + +logger = logging.getLogger(__name__) + +KB = 1024 +MB = 1024 * KB +GB = 1024 * MB + + +def base64_md5(data): + return b64encode(hashlib.md5(data).digest()).decode('utf-8') + + +def create_md5_parts(number_parts, offset, file_like): + return [{ + 'part_number': i + 1, 'md5': base64_md5(file_like[i * offset:(i + 1) * offset]) + } for i in range(number_parts)] + + +class CollectionAssetUploadBaseTest(StacBaseTestCase, S3TestMixin): + + @mock_s3_asset_file + def setUp(self): # pylint: disable=invalid-name + self.client = Client() + client_login(self.client) + self.factory = Factory() + self.collection = self.factory.create_collection_sample().model + self.asset = self.factory.create_collection_asset_sample( + collection=self.collection, sample='asset-no-file' + ).model + self.maxDiff = None # pylint: disable=invalid-name + + def get_asset_upload_queryset(self): + return CollectionAssetUpload.objects.all().filter( + asset__collection__name=self.collection.name, + asset__name=self.asset.name, + ) + + def get_delete_asset_path(self): + return reverse_version( + 'collection-asset-detail', args=[self.collection.name, self.asset.name] + ) + + def get_get_multipart_uploads_path(self): + return reverse_version( + 'collection-asset-uploads-list', args=[self.collection.name, self.asset.name] + ) + + def get_create_multipart_upload_path(self): + return reverse_version( + 'collection-asset-uploads-list', args=[self.collection.name, self.asset.name] + ) + + def get_abort_multipart_upload_path(self, upload_id): + return reverse_version( + 'collection-asset-upload-abort', + args=[self.collection.name, self.asset.name, upload_id] + ) + + def get_complete_multipart_upload_path(self, upload_id): + return reverse_version( + 'collection-asset-upload-complete', + args=[self.collection.name, self.asset.name, upload_id] + ) + + def get_list_parts_path(self, upload_id): + return reverse_version( + 'collection-asset-upload-parts-list', + args=[self.collection.name, self.asset.name, upload_id] + ) + + def s3_upload_parts(self, upload_id, file_like, size, number_parts): + s3 = get_s3_client() + key = get_collection_asset_path(self.collection, self.asset.name) + parts = [] + # split the file into parts + start = 0 + offset = size // number_parts + for part in range(1, number_parts + 1): + # use the s3 client to upload the file instead of the presigned url due to the s3 + # mocking + response = s3.upload_part( + Body=file_like[start:start + offset], + Bucket=settings.AWS_SETTINGS['legacy']['S3_BUCKET_NAME'], + Key=key, + PartNumber=part, + UploadId=upload_id + ) + start += offset + parts.append({'etag': response['ETag'], 'part_number': part}) + return parts + + def check_urls_response(self, urls, number_parts): + now = utc_aware(datetime.utcnow()) + self.assertEqual(len(urls), number_parts) + for i, url in enumerate(urls): + self.assertListEqual( + list(url.keys()), ['url', 'part', 'expires'], msg='Url dictionary keys missing' + ) + self.assertEqual( + url['part'], i + 1, msg=f'Part {url["part"]} does not match the url index {i}' + ) + try: + url_parsed = parse.urlparse(url["url"]) + self.assertIn(url_parsed[0], ['http', 
'https']) + except ValueError as error: + self.fail(msg=f"Invalid url {url['url']} for part {url['part']}: {error}") + try: + expires_dt = fromisoformat(url['expires']) + self.assertGreater( + expires_dt, + now, + msg=f"expires {url['expires']} for part {url['part']} is not in future" + ) + except ValueError as error: + self.fail(msg=f"Invalid expires {url['expires']} for part {url['part']}: {error}") + + def check_created_response(self, json_response): + self.assertNotIn('completed', json_response) + self.assertNotIn('aborted', json_response) + self.assertIn('upload_id', json_response) + self.assertIn('status', json_response) + self.assertIn('number_parts', json_response) + self.assertIn('file:checksum', json_response) + self.assertIn('urls', json_response) + self.assertEqual(json_response['status'], 'in-progress') + + def check_completed_response(self, json_response): + self.assertNotIn('urls', json_response) + self.assertNotIn('aborted', json_response) + self.assertIn('upload_id', json_response) + self.assertIn('status', json_response) + self.assertIn('number_parts', json_response) + self.assertIn('file:checksum', json_response) + self.assertIn('completed', json_response) + self.assertEqual(json_response['status'], 'completed') + self.assertGreater( + fromisoformat(json_response['completed']), fromisoformat(json_response['created']) + ) + + def check_aborted_response(self, json_response): + self.assertNotIn('urls', json_response) + self.assertNotIn('completed', json_response) + self.assertIn('upload_id', json_response) + self.assertIn('status', json_response) + self.assertIn('number_parts', json_response) + self.assertIn('file:checksum', json_response) + self.assertIn('aborted', json_response) + self.assertEqual(json_response['status'], 'aborted') + self.assertGreater( + fromisoformat(json_response['aborted']), fromisoformat(json_response['created']) + ) + + +class CollectionAssetUploadCreateEndpointTestCase(CollectionAssetUploadBaseTest): + + def test_asset_upload_create_abort_multipart(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 2 + file_like, checksum_multihash = get_file_like_object(1 * KB) + offset = 1 * KB // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'file:checksum': checksum_multihash, + 'md5_parts': md5_parts + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_created_response(json_data) + + self.check_urls_response(json_data['urls'], number_parts) + + response = self.client.post( + self.get_abort_multipart_upload_path(json_data['upload_id']), + data={}, + content_type="application/json" + ) + self.assertStatusCode(200, response) + json_data = response.json() + self.check_aborted_response(json_data) + self.assertFalse( + self.get_asset_upload_queryset().filter( + status=CollectionAssetUpload.Status.IN_PROGRESS + ).exists(), + msg='In progress upload found' + ) + self.assertTrue( + self.get_asset_upload_queryset().filter(status=CollectionAssetUpload.Status.ABORTED + ).exists(), + msg='Aborted upload not found' + ) + # check that there is only one multipart upload on S3 + s3 = get_s3_client() + response = s3.list_multipart_uploads( + Bucket=settings.AWS_SETTINGS['legacy']['S3_BUCKET_NAME'], KeyMarker=key + ) + self.assertNotIn('Uploads', response, msg='uploads found on S3') + + 
def test_asset_upload_create_multipart_duplicate(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 2 + file_like, checksum_multihash = get_file_like_object(1 * KB) + offset = 1 * KB // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'file:checksum': checksum_multihash, + 'md5_parts': md5_parts + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_created_response(json_data) + self.check_urls_response(json_data['urls'], number_parts) + initial_upload_id = json_data['upload_id'] + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'file:checksum': checksum_multihash, + 'md5_parts': md5_parts + }, + content_type="application/json" + ) + self.assertStatusCode(409, response) + self.assertEqual(response.json()['description'], "Upload already in progress") + self.assertIn( + "upload_id", + response.json(), + msg="The upload id of the current upload is missing from response" + ) + self.assertEqual( + initial_upload_id, + response.json()['upload_id'], + msg="Current upload ID not matching the one from the 409 Conflict response" + ) + + self.assertEqual( + self.get_asset_upload_queryset().filter( + status=CollectionAssetUpload.Status.IN_PROGRESS + ).count(), + 1, + msg='More than one upload in progress' + ) + + # check that there is only one multipart upload on S3 + s3 = get_s3_client() + response = s3.list_multipart_uploads( + Bucket=settings.AWS_SETTINGS['legacy']['S3_BUCKET_NAME'], KeyMarker=key + ) + self.assertIn('Uploads', response, msg='Failed to retrieve the upload list from s3') + self.assertEqual(len(response['Uploads']), 1, msg='More or less uploads found on S3') + + +class CollectionAssetUploadCreateRaceConditionTest(StacBaseTransactionTestCase, S3TestMixin): + + @mock_s3_asset_file + def setUp(self): + self.username = 'user' + self.password = 'dummy-password' + get_user_model().objects.create_superuser(self.username, password=self.password) + self.factory = Factory() + self.collection = self.factory.create_collection_sample().model + self.asset = self.factory.create_collection_asset_sample( + collection=self.collection, sample='asset-no-file' + ).model + + def test_asset_upload_create_race_condition(self): + workers = 5 + + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 2 + file_like, checksum_multihash = get_file_like_object(1 * KB) + offset = 1 * KB // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + path = reverse_version( + 'collection-asset-uploads-list', args=[self.collection.name, self.asset.name] + ) + + def asset_upload_atomic_create_test(worker): + # This method run on separate thread therefore it requires to create a new client and + # to login it for each call. + client = Client() + client.login(username=self.username, password=self.password) + return client.post( + path, + data={ + 'number_parts': number_parts, + 'file:checksum': checksum_multihash, + 'md5_parts': md5_parts + }, + content_type="application/json" + ) + + # We call the POST asset upload several times in parallel with the same data to make sure + # that we don't have any race condition. 
+ results, errors = self.run_parallel(workers, asset_upload_atomic_create_test) + + for _, response in results: + self.assertStatusCode([201, 409], response) + + ok_201 = [r for _, r in results if r.status_code == 201] + bad_409 = [r for _, r in results if r.status_code == 409] + self.assertEqual(len(ok_201), 1, msg="More than 1 parallel request was successful") + for response in bad_409: + self.assertEqual(response.json()['description'], "Upload already in progress") + + +class CollectionAssetUpload1PartEndpointTestCase(CollectionAssetUploadBaseTest): + + def upload_asset_with_dyn_cache(self, update_interval=None): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 1 + size = 1 * KB + file_like, checksum_multihash = get_file_like_object(size) + md5_parts = [{'part_number': 1, 'md5': base64_md5(file_like)}] + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'md5_parts': md5_parts, + 'file:checksum': checksum_multihash, + 'update_interval': update_interval + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_created_response(json_data) + self.check_urls_response(json_data['urls'], number_parts) + self.assertIn('md5_parts', json_data) + self.assertEqual(json_data['md5_parts'], md5_parts) + + parts = self.s3_upload_parts(json_data['upload_id'], file_like, size, number_parts) + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': parts}, + content_type="application/json" + ) + self.assertStatusCode(200, response) + self.check_completed_response(response.json()) + return key + + def test_asset_upload_1_part_md5_integrity(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 1 + size = 1 * KB + file_like, checksum_multihash = get_file_like_object(size) + md5_parts = [{'part_number': 1, 'md5': base64_md5(file_like)}] + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'md5_parts': md5_parts, + 'file:checksum': checksum_multihash + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_created_response(json_data) + self.check_urls_response(json_data['urls'], number_parts) + self.assertIn('md5_parts', json_data) + self.assertEqual(json_data['md5_parts'], md5_parts) + + parts = self.s3_upload_parts(json_data['upload_id'], file_like, size, number_parts) + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': parts}, + content_type="application/json" + ) + self.assertStatusCode(200, response) + self.check_completed_response(response.json()) + self.assertS3ObjectExists(key) + obj = self.get_s3_object(key) + self.assertS3ObjectCacheControl(obj, key, max_age=7200) + self.assertS3ObjectContentEncoding(obj, key, None) + + def test_asset_upload_dyn_cache(self): + key = self.upload_asset_with_dyn_cache(update_interval=600) + self.assertS3ObjectExists(key) + obj = self.get_s3_object(key) + self.assertS3ObjectCacheControl(obj, key, max_age=8) + + def test_asset_upload_no_cache(self): + key = self.upload_asset_with_dyn_cache(update_interval=5) + self.assertS3ObjectExists(key) + obj = self.get_s3_object(key) + self.assertS3ObjectCacheControl(obj, key, no_cache=True) + + def 
test_asset_upload_no_content_encoding(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 1 + size = 1 * KB + file_like, checksum_multihash = get_file_like_object(size) + md5_parts = [{'part_number': 1, 'md5': base64_md5(file_like)}] + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'md5_parts': md5_parts, + 'file:checksum': checksum_multihash + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_created_response(json_data) + self.check_urls_response(json_data['urls'], number_parts) + self.assertIn('md5_parts', json_data) + self.assertEqual(json_data['md5_parts'], md5_parts) + + parts = self.s3_upload_parts(json_data['upload_id'], file_like, size, number_parts) + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': parts}, + content_type="application/json" + ) + self.assertStatusCode(200, response) + self.check_completed_response(response.json()) + self.assertS3ObjectExists(key) + obj = self.get_s3_object(key) + self.assertS3ObjectContentEncoding(obj, key, None) + + def test_asset_upload_gzip(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 1 + size = 1 * MB + file_like, checksum_multihash = get_file_like_object(size) + file_like_compress = gzip.compress(file_like) + size_compress = len(file_like_compress) + md5_parts = [{'part_number': 1, 'md5': base64_md5(file_like_compress)}] + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'md5_parts': md5_parts, + 'file:checksum': checksum_multihash, + 'content_encoding': 'gzip' + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_created_response(json_data) + self.check_urls_response(json_data['urls'], number_parts) + self.assertIn('md5_parts', json_data) + self.assertEqual(json_data['md5_parts'], md5_parts) + + parts = self.s3_upload_parts( + json_data['upload_id'], file_like_compress, size_compress, number_parts + ) + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': parts}, + content_type="application/json" + ) + self.assertStatusCode(200, response) + self.check_completed_response(response.json()) + self.assertS3ObjectExists(key) + obj = self.get_s3_object(key) + self.assertS3ObjectContentEncoding(obj, key, encoding='gzip') + + +class CollectionAssetUpload2PartEndpointTestCase(CollectionAssetUploadBaseTest): + + def test_asset_upload_2_parts_md5_integrity(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 2 + size = 10 * MB # Minimum upload part on S3 is 5 MB + file_like, checksum_multihash = get_file_like_object(size) + + offset = size // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'md5_parts': md5_parts, + 'file:checksum': checksum_multihash + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_created_response(json_data) + self.check_urls_response(json_data['urls'], number_parts) + + parts = 
self.s3_upload_parts(json_data['upload_id'], file_like, size, number_parts) + + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': parts}, + content_type="application/json" + ) + self.assertStatusCode(200, response) + self.check_completed_response(response.json()) + self.assertS3ObjectExists(key) + + +class CollectionAssetUploadInvalidEndpointTestCase(CollectionAssetUploadBaseTest): + + def test_asset_upload_invalid_content_encoding(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 2 + size = 10 * MB # Minimum upload part on S3 is 5 MB + file_like, checksum_multihash = get_file_like_object(size) + offset = size // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'md5_parts': md5_parts, + 'file:checksum': checksum_multihash, + 'content_encoding': 'hello world' + }, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + {'content_encoding': ['Invalid encoding "hello world": must be one of ' + '"br, gzip"']} + ) + + def test_asset_upload_1_part_no_md5(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 1 + size = 1 * KB + file_like, checksum_multihash = get_file_like_object(size) + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, 'file:checksum': checksum_multihash + }, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual(response.json()['description'], {'md5_parts': ['This field is required.']}) + + def test_asset_upload_2_parts_no_md5(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 2 + size = 10 * MB # Minimum upload part on S3 is 5 MB + file_like, checksum_multihash = get_file_like_object(size) + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, 'file:checksum': checksum_multihash + }, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual(response.json()['description'], {'md5_parts': ['This field is required.']}) + + def test_asset_upload_create_empty_payload(self): + response = self.client.post( + self.get_create_multipart_upload_path(), data={}, content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + { + 'file:checksum': ['This field is required.'], + 'number_parts': ['This field is required.'], + 'md5_parts': ['This field is required.'] + } + ) + + def test_asset_upload_create_invalid_data(self): + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': 0, + "file:checksum": 'abcdef', + "md5_parts": [{ + "part_number": '0', "md5": 'abcdef' + }] + }, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + { + 'file:checksum': ['Invalid multihash value; Invalid varint provided'], + 'number_parts': ['Ensure this value is greater than or equal to 1.'] + } + ) + + def test_asset_upload_create_too_many_parts(self): + + number_parts = 101 + md5_parts = [{'part_number': i + 1, 
'md5': 'abcdef'} for i in range(number_parts)] + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': 101, "file:checksum": 'abcdef', 'md5_parts': md5_parts + }, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + { + 'file:checksum': ['Invalid multihash value; Invalid varint provided'], + 'number_parts': ['Ensure this value is less than or equal to 100.'] + } + ) + + def test_asset_upload_create_empty_md5_parts(self): + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': 2, + "md5_parts": [], + "file:checksum": + '12200ADEC47F803A8CF1055ED36750B3BA573C79A3AF7DA6D6F5A2AED03EA16AF3BC' + }, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + { + 'non_field_errors': [ + 'Missing, too many or duplicate part_number in md5_parts field list: ' + 'list should have 2 item(s).' + ] + } + ) + + def test_asset_upload_create_duplicate_md5_parts(self): + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': 3, + "md5_parts": [{ + 'part_number': 1, 'md5': 'asdf' + }, { + 'part_number': 1, 'md5': 'asdf' + }, { + 'part_number': 2, 'md5': 'asdf' + }], + "file:checksum": + '12200ADEC47F803A8CF1055ED36750B3BA573C79A3AF7DA6D6F5A2AED03EA16AF3BC' + }, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + { + 'non_field_errors': [ + 'Missing, too many or duplicate part_number in md5_parts field list: ' + 'list should have 3 item(s).' + ] + } + ) + + def test_asset_upload_create_too_many_md5_parts(self): + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': 2, + "md5_parts": [{ + 'part_number': 1, 'md5': 'asdf' + }, { + 'part_number': 2, 'md5': 'asdf' + }, { + 'part_number': 3, 'md5': 'asdf' + }], + "file:checksum": + '12200ADEC47F803A8CF1055ED36750B3BA573C79A3AF7DA6D6F5A2AED03EA16AF3BC' + }, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + { + 'non_field_errors': [ + 'Missing, too many or duplicate part_number in md5_parts field list: ' + 'list should have 2 item(s).' 
+ ] + } + ) + + def test_asset_upload_create_md5_parts_missing_part_number(self): + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': 2, + "md5_parts": [ + { + 'part_number': 1, 'md5': 'asdf' + }, + { + 'md5': 'asdf' + }, + ], + "file:checksum": + '12200ADEC47F803A8CF1055ED36750B3BA573C79A3AF7DA6D6F5A2AED03EA16AF3BC' + }, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + {'non_field_errors': ['Invalid md5_parts[1] value: part_number field missing']} + ) + + def test_asset_upload_2_parts_too_small(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 2 + size = 1 * KB # Minimum upload part on S3 is 5 MB + file_like, checksum_multihash = get_file_like_object(size) + offset = size // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'file:checksum': checksum_multihash, + 'md5_parts': md5_parts + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_urls_response(json_data['urls'], number_parts) + + parts = self.s3_upload_parts(json_data['upload_id'], file_like, size, number_parts) + + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': parts}, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + [ + 'An error occurred (EntityTooSmall) when calling the CompleteMultipartUpload ' + 'operation: Your proposed upload is smaller than the minimum allowed object size.' + ] + ) + self.assertS3ObjectNotExists(key) + + def test_asset_upload_1_parts_invalid_etag(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 1 + size = 1 * KB + file_like, checksum_multihash = get_file_like_object(size) + offset = size // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'file:checksum': checksum_multihash, + 'md5_parts': md5_parts + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_urls_response(json_data['urls'], number_parts) + + parts = self.s3_upload_parts(json_data['upload_id'], file_like, size, number_parts) + + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': [{ + 'etag': 'dummy', 'part_number': 1 + }]}, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + [ + 'An error occurred (InvalidPart) when calling the CompleteMultipartUpload ' + 'operation: One or more of the specified parts could not be found. The part ' + 'might not have been uploaded, or the specified entity tag might not have ' + "matched the part's entity tag." 
+ ] + ) + self.assertS3ObjectNotExists(key) + + def test_asset_upload_1_parts_too_many_parts_in_complete(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 1 + size = 1 * KB + file_like, checksum_multihash = get_file_like_object(size) + offset = size // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'file:checksum': checksum_multihash, + 'md5_parts': md5_parts + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_urls_response(json_data['urls'], number_parts) + + parts = self.s3_upload_parts(json_data['upload_id'], file_like, size, number_parts) + parts.append({'etag': 'dummy', 'number_part': 2}) + + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': parts}, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual(response.json()['description'], {'parts': ['Too many parts']}) + self.assertS3ObjectNotExists(key) + + def test_asset_upload_2_parts_incomplete_upload(self): + number_parts = 2 + size = 10 * MB + file_like, checksum_multihash = get_file_like_object(size) + offset = size // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'file:checksum': checksum_multihash, + 'md5_parts': md5_parts + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_urls_response(json_data['urls'], number_parts) + + parts = self.s3_upload_parts(json_data['upload_id'], file_like, size // 2, 1) + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': parts}, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual(response.json()['description'], {'parts': ['Too few parts']}) + + def test_asset_upload_1_parts_invalid_complete(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 1 + size = 1 * KB + file_like, checksum_multihash = get_file_like_object(size) + offset = size // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'file:checksum': checksum_multihash, + 'md5_parts': md5_parts + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_urls_response(json_data['urls'], number_parts) + + parts = self.s3_upload_parts(json_data['upload_id'], file_like, size, number_parts) + + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={}, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual(response.json()['description'], {'parts': 'Missing required field'}) + + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': []}, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual(response.json()['description'], {'parts': ['This list may not be empty.']}) 
+ + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': ["dummy-etag"]}, + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + { + 'parts': { + '0': { + 'non_field_errors': + ['Invalid data. Expected a dictionary, ' + 'but got str.'] + } + } + } + ) + self.assertS3ObjectNotExists(key) + + def test_asset_upload_1_parts_duplicate_complete(self): + key = get_collection_asset_path(self.collection, self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 1 + size = 1 * KB + file_like, checksum_multihash = get_file_like_object(size) + offset = size // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'file:checksum': checksum_multihash, + 'md5_parts': md5_parts + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + self.check_urls_response(json_data['urls'], number_parts) + + parts = self.s3_upload_parts(json_data['upload_id'], file_like, size, number_parts) + + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': parts}, + content_type="application/json" + ) + self.assertStatusCode(200, response) + + response = self.client.post( + self.get_complete_multipart_upload_path(json_data['upload_id']), + data={'parts': parts}, + content_type="application/json" + ) + self.assertStatusCode(409, response) + self.assertEqual(response.json()['code'], 409) + self.assertEqual(response.json()['description'], 'No upload in progress') + + +class CollectionAssetUploadDeleteInProgressEndpointTestCase(CollectionAssetUploadBaseTest): + + def test_delete_asset_upload_in_progress(self): + number_parts = 2 + size = 10 * MB # Minimum upload part on S3 is 5 MB + file_like, checksum_multihash = get_file_like_object(size) + offset = size // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'file:checksum': checksum_multihash, + 'md5_parts': md5_parts + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + upload_id = response.json()['upload_id'] + + response = self.client.delete(self.get_delete_asset_path()) + self.assertStatusCode(400, response) + self.assertEqual( + response.json()['description'], + ['Collection Asset collection-asset-1.tiff has still an upload in progress'] + ) + + self.assertTrue( + CollectionAsset.objects.all().filter( + collection__name=self.collection.name, name=self.asset.name + ).exists(), + msg='Collection Asset has been deleted' + ) + + response = self.client.post(self.get_abort_multipart_upload_path(upload_id)) + self.assertStatusCode(200, response) + + response = self.client.delete(self.get_delete_asset_path()) + self.assertStatusCode(200, response) + + self.assertFalse( + CollectionAsset.objects.all().filter( + collection__name=self.collection.name, name=self.asset.name + ).exists(), + msg='Collection Asset has not been deleted' + ) + + +class GetCollectionAssetUploadsEndpointTestCase(CollectionAssetUploadBaseTest): + + def create_dummies_uploads(self): + # Create some asset uploads + for i in range(1, 4): + CollectionAssetUpload.objects.create( + asset=self.asset, + upload_id=f'upload-{i}', + 
status=CollectionAssetUpload.Status.ABORTED, + checksum_multihash=get_sha256_multihash(b'upload-%d' % i), + number_parts=2, + ended=utc_aware(datetime.utcnow()), + md5_parts=[] + ) + for i in range(4, 8): + CollectionAssetUpload.objects.create( + asset=self.asset, + upload_id=f'upload-{i}', + status=CollectionAssetUpload.Status.COMPLETED, + checksum_multihash=get_sha256_multihash(b'upload-%d' % i), + number_parts=2, + ended=utc_aware(datetime.utcnow()), + md5_parts=[] + ) + CollectionAssetUpload.objects.create( + asset=self.asset, + upload_id='upload-8', + status=CollectionAssetUpload.Status.IN_PROGRESS, + checksum_multihash=get_sha256_multihash(b'upload-8'), + number_parts=2, + md5_parts=[] + ) + self.maxDiff = None # pylint: disable=invalid-name + + def test_get_asset_uploads(self): + self.create_dummies_uploads() + response = self.client.get(self.get_get_multipart_uploads_path()) + self.assertStatusCode(200, response) + json_data = response.json() + self.assertIn('links', json_data) + self.assertEqual(json_data['links'], []) + self.assertIn('uploads', json_data) + self.assertEqual(len(json_data['uploads']), self.get_asset_upload_queryset().count()) + self.assertEqual( + [ + 'upload_id', + 'status', + 'created', + 'aborted', + 'number_parts', + 'update_interval', + 'content_encoding', + 'file:checksum' + ], + list(json_data['uploads'][0].keys()), + ) + self.assertEqual( + [ + 'upload_id', + 'status', + 'created', + 'completed', + 'number_parts', + 'update_interval', + 'content_encoding', + 'file:checksum' + ], + list(json_data['uploads'][4].keys()), + ) + self.assertEqual( + [ + 'upload_id', + 'status', + 'created', + 'number_parts', + 'update_interval', + 'content_encoding', + 'file:checksum' + ], + list(json_data['uploads'][7].keys()), + ) + + def test_get_asset_uploads_with_content_encoding(self): + CollectionAssetUpload.objects.create( + asset=self.asset, + upload_id='upload-content-encoding', + status=CollectionAssetUpload.Status.COMPLETED, + checksum_multihash=get_sha256_multihash(b'upload-content-encoding'), + number_parts=2, + ended=utc_aware(datetime.utcnow()), + md5_parts=[], + content_encoding='gzip' + ) + response = self.client.get(self.get_get_multipart_uploads_path()) + self.assertStatusCode(200, response) + json_data = response.json() + self.assertIn('links', json_data) + self.assertEqual(json_data['links'], []) + self.assertIn('uploads', json_data) + self.assertEqual(len(json_data['uploads']), self.get_asset_upload_queryset().count()) + self.assertEqual( + [ + 'upload_id', + 'status', + 'created', + 'completed', + 'number_parts', + 'update_interval', + 'content_encoding', + 'file:checksum' + ], + list(json_data['uploads'][0].keys()), + ) + self.assertEqual('gzip', json_data['uploads'][0]['content_encoding']) + + def test_get_asset_uploads_status_query(self): + response = self.client.get( + self.get_get_multipart_uploads_path(), {'status': CollectionAssetUpload.Status.ABORTED} + ) + self.assertStatusCode(200, response) + json_data = response.json() + self.assertIn('uploads', json_data) + self.assertGreater(len(json_data), 1) + self.assertEqual( + len(json_data['uploads']), + self.get_asset_upload_queryset().filter(status=CollectionAssetUpload.Status.ABORTED + ).count(), + ) + for upload in json_data['uploads']: + self.assertEqual(upload['status'], CollectionAssetUpload.Status.ABORTED) + + +class CollectionAssetUploadListPartsEndpointTestCase(CollectionAssetUploadBaseTest): + + def test_asset_upload_list_parts(self): + key = get_collection_asset_path(self.collection, 
self.asset.name) + self.assertS3ObjectNotExists(key) + number_parts = 4 + size = 5 * MB * number_parts + file_like, checksum_multihash = get_file_like_object(size) + offset = size // number_parts + md5_parts = create_md5_parts(number_parts, offset, file_like) + response = self.client.post( + self.get_create_multipart_upload_path(), + data={ + 'number_parts': number_parts, + 'file:checksum': checksum_multihash, + 'md5_parts': md5_parts + }, + content_type="application/json" + ) + self.assertStatusCode(201, response) + json_data = response.json() + upload_id = json_data['upload_id'] + self.check_urls_response(json_data['urls'], number_parts) + + # List the uploaded parts should be empty + response = self.client.get(self.get_list_parts_path(upload_id)) + self.assertStatusCode(200, response) + json_data = response.json() + self.assertIn('links', json_data, msg='missing required field in list parts response') + self.assertIn('parts', json_data, msg='missing required field in list parts response') + self.assertEqual(len(json_data['parts']), 0, msg='parts should be empty') + + # upload all the parts + parts = self.s3_upload_parts(upload_id, file_like, size, number_parts) + + # List the uploaded parts should have 4 parts + response = self.client.get(self.get_list_parts_path(upload_id)) + self.assertStatusCode(200, response) + json_data = response.json() + self.assertIn('links', json_data, msg='missing required field in list parts response') + self.assertIn('parts', json_data, msg='missing required field in list parts response') + self.assertEqual(len(json_data['parts']), number_parts) + for part in json_data['parts']: + self.assertIn('etag', part) + self.assertIn('modified', part) + self.assertIn('size', part) + self.assertIn('part_number', part) + + # Unfortunately moto doesn't support yet the MaxParts + # (see https://github.com/spulec/moto/issues/2680) + # Test the list parts pagination + # response = self.client.get(self.get_list_parts_path(upload_id), {'limit': 2}) + # self.assertStatusCode(200, response) + # json_data = response.json() + # self.assertIn('links', json_data, msg='missing required field in list parts response') + # self.assertIn('parts', json_data, msg='missing required field in list parts response') + # self.assertEqual(len(json_data['parts']), number_parts) + # for part in json_data['parts']: + # self.assertIn('etag', part) + # self.assertIn('modified', part) + # self.assertIn('size', part) + # self.assertIn('part_number', part) + + # Complete the upload + response = self.client.post( + self.get_complete_multipart_upload_path(upload_id), + data={'parts': parts}, + content_type="application/json" + ) + self.assertStatusCode(200, response) + self.assertS3ObjectExists(key) diff --git a/app/tests/tests_10/test_collection_asset_upload_model.py b/app/tests/tests_10/test_collection_asset_upload_model.py new file mode 100644 index 00000000..9ea016aa --- /dev/null +++ b/app/tests/tests_10/test_collection_asset_upload_model.py @@ -0,0 +1,212 @@ +import logging +from datetime import datetime + +from django.core.exceptions import ValidationError +from django.db.models import ProtectedError +from django.test import TestCase +from django.test import TransactionTestCase + +from stac_api.models import CollectionAsset +from stac_api.models import CollectionAssetUpload +from stac_api.utils import get_sha256_multihash +from stac_api.utils import utc_aware + +from tests.tests_10.data_factory import Factory +from tests.utils import mock_s3_asset_file + +logger = logging.getLogger(__name__) + + 
+class CollectionAssetUploadTestCaseMixin: + + def create_asset_upload(self, asset, upload_id, **kwargs): + asset_upload = CollectionAssetUpload( + asset=asset, + upload_id=upload_id, + checksum_multihash=get_sha256_multihash(b'Test'), + number_parts=1, + md5_parts=["this is an md5 value"], + **kwargs + ) + asset_upload.full_clean() + asset_upload.save() + self.assertEqual( + asset_upload, + CollectionAssetUpload.objects.get( + upload_id=upload_id, + asset__name=asset.name, + asset__collection__name=asset.collection.name + ) + ) + return asset_upload + + def update_asset_upload(self, asset_upload, **kwargs): + for kwarg, value in kwargs.items(): + setattr(asset_upload, kwarg, value) + asset_upload.full_clean() + asset_upload.save() + asset_upload.refresh_from_db() + self.assertEqual( + asset_upload, + CollectionAssetUpload.objects.get( + upload_id=asset_upload.upload_id, asset__name=asset_upload.asset.name + ) + ) + return asset_upload + + def check_etag(self, etag): + self.assertIsInstance(etag, str, msg="Etag must be a string") + self.assertNotEqual(etag, '', msg='Etag should not be empty') + + +class CollectionAssetUploadModelTestCase(TestCase, CollectionAssetUploadTestCaseMixin): + + @classmethod + @mock_s3_asset_file + def setUpTestData(cls): + cls.factory = Factory() + cls.collection = cls.factory.create_collection_sample().model + cls.asset_1 = cls.factory.create_collection_asset_sample(collection=cls.collection).model + cls.asset_2 = cls.factory.create_collection_asset_sample(collection=cls.collection).model + + def test_create_asset_upload_default(self): + asset_upload = self.create_asset_upload(self.asset_1, 'default-upload') + self.assertEqual(asset_upload.urls, [], msg="Wrong default value") + self.assertEqual(asset_upload.ended, None, msg="Wrong default value") + self.assertAlmostEqual( + utc_aware(datetime.utcnow()).timestamp(), + asset_upload.created.timestamp(), # pylint: disable=no-member + delta=1, + msg="Wrong default value" + ) + + def test_unique_constraint(self): + # Check that asset upload is unique in collection/asset + # therefore the following asset upload should be ok + # collection-1/asset-1/default-upload + # collection-2/asset-1/default-upload + collection_2 = self.factory.create_collection_sample().model + asset_2 = self.factory.create_collection_asset_sample( + collection_2, name=self.asset_1.name + ).model + asset_upload_1 = self.create_asset_upload(self.asset_1, 'default-upload') + asset_upload_2 = self.create_asset_upload(asset_2, 'default-upload') + self.assertEqual(asset_upload_1.upload_id, asset_upload_2.upload_id) + self.assertEqual(asset_upload_1.asset.name, asset_upload_2.asset.name) + self.assertNotEqual( + asset_upload_1.asset.collection.name, asset_upload_2.asset.collection.name + ) + # But duplicate path are not allowed + with self.assertRaises(ValidationError, msg="Existing asset upload could be re-created."): + asset_upload_3 = self.create_asset_upload(self.asset_1, 'default-upload') + + def test_create_asset_upload_duplicate_in_progress(self): + # create a first upload on asset 1 + asset_upload_1 = self.create_asset_upload(self.asset_1, '1st-upload') + + # create a first upload on asset 2 + asset_upload_2 = self.create_asset_upload(self.asset_2, '1st-upload') + + # create a second upload on asset 1 should not be allowed. + with self.assertRaises( + ValidationError, msg="Existing asset upload already in progress could be re-created." 
+ ): + asset_upload_3 = self.create_asset_upload(self.asset_1, '2nd-upload') + + def test_asset_upload_etag(self): + asset_upload = self.create_asset_upload(self.asset_1, 'default-upload') + original_etag = asset_upload.etag + self.check_etag(original_etag) + asset_upload = self.update_asset_upload( + asset_upload, status=CollectionAssetUpload.Status.ABORTED + ) + self.check_etag(asset_upload.etag) + self.assertNotEqual(asset_upload.etag, original_etag, msg='Etag was not updated') + + def test_asset_upload_invalid_number_parts(self): + with self.assertRaises(ValidationError): + asset_upload = CollectionAssetUpload( + asset=self.asset_1, + upload_id='my-upload-id', + checksum_multihash=get_sha256_multihash(b'Test'), + number_parts=-1, + md5_parts=['fake_md5'] + ) + asset_upload.full_clean() + asset_upload.save() + + +class CollectionAssetUploadDeleteProtectModelTestCase( + TransactionTestCase, CollectionAssetUploadTestCaseMixin +): + + @mock_s3_asset_file + def setUp(self): + self.factory = Factory() + self.collection = self.factory.create_collection_sample().model + self.asset = self.factory.create_collection_asset_sample(collection=self.collection).model + + def test_delete_asset_upload(self): + upload_id = 'upload-in-progress' + asset_upload = self.create_asset_upload(self.asset, upload_id) + + with self.assertRaises(ProtectedError, msg="Deleting an upload in progress not allowed"): + asset_upload.delete() + + asset_upload = self.update_asset_upload( + asset_upload, + status=CollectionAssetUpload.Status.COMPLETED, + ended=utc_aware(datetime.utcnow()) + ) + + asset_upload.delete() + self.assertFalse( + CollectionAssetUpload.objects.all().filter( + upload_id=upload_id, asset__name=self.asset.name + ).exists() + ) + + def test_delete_asset_with_upload_in_progress(self): + asset_upload_1 = self.create_asset_upload(self.asset, 'upload-in-progress') + asset_upload_2 = self.create_asset_upload( + self.asset, + 'upload-completed', + status=CollectionAssetUpload.Status.COMPLETED, + ended=utc_aware(datetime.utcnow()) + ) + asset_upload_3 = self.create_asset_upload( + self.asset, + 'upload-aborted', + status=CollectionAssetUpload.Status.ABORTED, + ended=utc_aware(datetime.utcnow()) + ) + asset_upload_4 = self.create_asset_upload( + self.asset, + 'upload-aborted-2', + status=CollectionAssetUpload.Status.ABORTED, + ended=utc_aware(datetime.utcnow()) + ) + + # Try to delete parent asset + with self.assertRaises(ValidationError): + self.asset.delete() + self.assertEqual(4, len(list(CollectionAssetUpload.objects.all()))) + self.assertTrue( + CollectionAsset.objects.all().filter( + name=self.asset.name, collection__name=self.collection.name + ).exists() + ) + + self.update_asset_upload( + asset_upload_1, + status=CollectionAssetUpload.Status.ABORTED, + ended=utc_aware(datetime.utcnow()) + ) + + self.asset.delete() + self.assertEqual(0, len(list(CollectionAssetUpload.objects.all()))) + self.assertFalse( + CollectionAsset.objects.all().filter( + name=self.asset.name, collection__name=self.collection.name + ).exists() + ) diff --git a/app/tests/tests_10/test_collection_assets_endpoint.py b/app/tests/tests_10/test_collection_assets_endpoint.py new file mode 100644 index 00000000..1596a8d9 --- /dev/null +++ b/app/tests/tests_10/test_collection_assets_endpoint.py @@ -0,0 +1,729 @@ +import logging +from datetime import datetime +from json import dumps +from json import loads +from pprint import pformat + +from django.contrib.auth import get_user_model +from django.test import Client +from django.urls import 
reverse
+
+from stac_api.models import CollectionAsset
+from stac_api.utils import get_collection_asset_path
+from stac_api.utils import utc_aware
+
+from tests.tests_10.base_test import STAC_BASE_V
+from tests.tests_10.base_test import StacBaseTestCase
+from tests.tests_10.base_test import StacBaseTransactionTestCase
+from tests.tests_10.data_factory import Factory
+from tests.tests_10.utils import reverse_version
+from tests.utils import S3TestMixin
+from tests.utils import client_login
+from tests.utils import disableLogger
+from tests.utils import mock_s3_asset_file
+
+logger = logging.getLogger(__name__)
+
+
+def to_dict(input_ordered_dict):
+ return loads(dumps(input_ordered_dict))
+
+
+class CollectionAssetsEndpointTestCase(StacBaseTestCase):
+
+ @mock_s3_asset_file
+ def setUp(self): # pylint: disable=invalid-name
+ self.client = Client()
+ self.factory = Factory()
+ self.collection = self.factory.create_collection_sample().model
+ self.asset_1 = self.factory.create_collection_asset_sample(
+ collection=self.collection, name="asset-1.tiff", db_create=True
+ )
+ self.maxDiff = None # pylint: disable=invalid-name
+
+ def test_assets_endpoint(self):
+ collection_name = self.collection.name
+ # To test the asset ordering make sure not to create them in ascending order
+ asset_2 = self.factory.create_collection_asset_sample(
+ collection=self.collection, sample='asset-2', name="asset-2.txt", db_create=True
+ )
+ asset_3 = self.factory.create_collection_asset_sample(
+ collection=self.collection, name="asset-0.pdf", sample='asset-3', db_create=True
+ )
+ response = self.client.get(f"/{STAC_BASE_V}/collections/{collection_name}/assets")
+ self.assertStatusCode(200, response)
+ json_data = response.json()
+ logger.debug('Response (%s):\n%s', type(json_data), pformat(json_data))
+
+ self.assertIn('assets', json_data, msg='assets is missing in response')
+ self.assertEqual(
+ 3, len(json_data['assets']), msg='Number of assets doesn\'t match the expected'
+ )
+
+ # Check that the output is sorted by name
+ asset_ids = [asset['id'] for asset in json_data['assets']]
+ self.assertListEqual(asset_ids, sorted(asset_ids), msg="Assets are not sorted by ID")
+
+ asset_samples = sorted([self.asset_1, asset_2, asset_3], key=lambda asset: asset['name'])
+ for i, asset in enumerate(asset_samples):
+ # self.check_stac_asset(asset.json, json_data['assets'][i], collection_name, item_name)
+ self.check_stac_collection_asset(asset.json, json_data['assets'][i], collection_name)
+
+ def test_assets_endpoint_collection_does_not_exist(self):
+ collection_name = "non-existent"
+ response = self.client.get(f"/{STAC_BASE_V}/collections/{collection_name}/assets")
+ self.assertStatusCode(404, response)
+
+ def test_single_asset_endpoint(self):
+ collection_name = self.collection.name
+ asset_name = self.asset_1["name"]
+ response = self.client.get(
+ f"/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}"
+ )
+ json_data = response.json()
+ self.assertStatusCode(200, response)
+ logger.debug('Response (%s):\n%s', type(json_data), pformat(json_data))
+
+ self.check_stac_collection_asset(self.asset_1.json, json_data, collection_name)
+
+ # The ETag changes between each test call due to the created and updated times that are in
+ # the hash computation of the ETag
+ self.assertEtagHeader(None, response)
+
+
+class CollectionAssetsUnimplementedEndpointTestCase(StacBaseTestCase):
+
+ @mock_s3_asset_file
+ def setUp(self): # pylint: disable=invalid-name
+ self.factory = Factory()
+ self.collection =
self.factory.create_collection_sample().model
+ self.client = Client()
+ client_login(self.client)
+ self.maxDiff = None # pylint: disable=invalid-name
+
+ def test_asset_unimplemented_post(self):
+ collection_name = self.collection.name
+ asset = self.factory.create_collection_asset_sample(
+ collection=self.collection, required_only=True
+ )
+ response = self.client.post(
+ f'/{STAC_BASE_V}/collections/{collection_name}/assets',
+ data=asset.get_json('post'),
+ content_type="application/json"
+ )
+ self.assertStatusCode(405, response)
+
+
+class CollectionAssetsCreateEndpointTestCase(StacBaseTestCase):
+
+ @mock_s3_asset_file
+ def setUp(self): # pylint: disable=invalid-name
+ self.factory = Factory()
+ self.collection = self.factory.create_collection_sample().model
+ self.client = Client()
+ client_login(self.client)
+ self.maxDiff = None # pylint: disable=invalid-name
+
+ def test_asset_upsert_create_only_required(self):
+ collection_name = self.collection.name
+ asset = self.factory.create_collection_asset_sample(
+ collection=self.collection, required_only=True
+ )
+ path = \
+ f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset["name"]}'
+ json_to_send = asset.get_json('put')
+ # Send a non-normalized form of the type to see if it is also accepted
+ json_to_send['type'] = 'image/TIFF;application=geotiff; Profile=cloud-optimized'
+ response = self.client.put(path, data=json_to_send, content_type="application/json")
+ json_data = response.json()
+ self.assertStatusCode(201, response)
+ self.assertLocationHeader(f"{path}", response)
+ self.check_stac_collection_asset(asset.json, json_data, collection_name)
+
+ # Check the data by reading it back
+ response = self.client.get(response['Location'])
+ json_data = response.json()
+ self.assertStatusCode(200, response)
+ self.check_stac_collection_asset(asset.json, json_data, collection_name)
+
+ # make sure that the optional fields are not present
+ self.assertNotIn('proj:epsg', json_data)
+ self.assertNotIn('description', json_data)
+ self.assertNotIn('title', json_data)
+ self.assertNotIn('file:checksum', json_data)
+
+ def test_asset_upsert_create(self):
+ collection = self.collection
+ asset = self.factory.create_collection_asset_sample(
+ collection=self.collection, sample='asset-no-checksum', create_asset_file=False
+ )
+ asset_name = asset['name']
+
+ response = self.client.get(
+ reverse_version('collection-asset-detail', args=[collection.name, asset_name])
+ )
+ # Check that the asset does not exist already
+ self.assertStatusCode(404, response)
+
+ # Check also that the asset does not exist in the DB already
+ self.assertFalse(
+ CollectionAsset.objects.filter(name=asset_name).exists(),
+ msg="Collection asset already exists"
+ )
+
+ # Now use upsert to create the new asset
+ response = self.client.put(
+ reverse_version('collection-asset-detail', args=[collection.name, asset_name]),
+ data=asset.get_json('put'),
+ content_type="application/json"
+ )
+ json_data = response.json()
+ self.assertStatusCode(201, response)
+ self.assertLocationHeader(
+ reverse_version('collection-asset-detail', args=[collection.name, asset_name]),
+ response
+ )
+ self.check_stac_collection_asset(asset.json, json_data, collection.name)
+
+ # make sure that all optional fields are present
+ self.assertIn('proj:epsg', json_data)
+ self.assertIn('description', json_data)
+ self.assertIn('title', json_data)
+
+ # Checksum multihash is set by the AssetUpload later on
+ self.assertNotIn('file:checksum', json_data)
+
+ # Check the data by reading it
back
+ response = self.client.get(response['Location'])
+ json_data = response.json()
+ self.assertStatusCode(200, response)
+ self.check_stac_collection_asset(asset.json, json_data, collection.name)
+
+ def test_asset_upsert_create_non_existing_parent_collection_in_path(self):
+ asset = self.factory.create_collection_asset_sample(
+ collection=self.collection, create_asset_file=False
+ )
+ asset_name = asset['name']
+
+ path = (f'/{STAC_BASE_V}/collections/non-existing-collection/assets/'
+ f'{asset_name}')
+
+ # Check that the asset does not exist already
+ response = self.client.get(path)
+ self.assertStatusCode(404, response)
+
+ # Check also that the asset does not exist in the DB already
+ self.assertFalse(
+ CollectionAsset.objects.filter(name=asset_name).exists(),
+ msg="Collection asset already exists in DB"
+ )
+
+ # Now use upsert to create the new asset
+ response = self.client.put(
+ path, data=asset.get_json('post'), content_type="application/json"
+ )
+ self.assertStatusCode(404, response)
+
+ def test_asset_upsert_create_empty_string(self):
+ collection_name = self.collection.name
+ asset = self.factory.create_collection_asset_sample(
+ collection=self.collection, required_only=True, description='', title=''
+ )
+
+ path = \
+ f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset["name"]}'
+ response = self.client.put(
+ path, data=asset.get_json('put'), content_type="application/json"
+ )
+ self.assertStatusCode(400, response)
+ json_data = response.json()
+ for field in ['description', 'title']:
+ self.assertIn(field, json_data['description'], msg=f'Field {field} error missing')
+
+ def invalid_request_wrapper(self, sample_name, expected_error_messages, **extra_params):
+ collection_name = self.collection.name
+ asset = self.factory.create_collection_asset_sample(
+ collection=self.collection, sample=sample_name, **extra_params
+ )
+
+ path = \
+ f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset["name"]}'
+ response = self.client.put(
+ path, data=asset.get_json('put'), content_type="application/json"
+ )
+ self.assertStatusCode(400, response)
+ self.assertEqual(
+ expected_error_messages,
+ response.json()['description'],
+ msg='Unexpected error message',
+ )
+
+ # Make sure that the asset is not found in DB
+ self.assertFalse(
+ CollectionAsset.objects.filter(name=asset.json['id']).exists(),
+ msg="Invalid asset has been created in DB"
+ )
+
+ def test_asset_upsert_create_invalid_data(self):
+ self.invalid_request_wrapper(
+ 'asset-invalid', {
+ 'proj:epsg': ['A valid integer is required.'],
+ 'type': ['Invalid media type "dummy"']
+ }
+ )
+
+ def test_asset_upsert_create_invalid_type(self):
+ media_type_str = "image/tiff; application=Geotiff; profile=cloud-optimized"
+ self.invalid_request_wrapper(
+ 'asset-invalid-type', {'type': [f'Invalid media type "{media_type_str}"']}
+ )
+
+ def test_asset_upsert_create_type_extension_mismatch(self):
+ media_type_str = "application/gml+xml"
+ self.invalid_request_wrapper(
+ 'asset-invalid-type',
+ {
+ 'non_field_errors': [
+ f"Invalid id extension '.tiff', id must match its media type {media_type_str}"
+ ]
+ },
+ media_type=media_type_str,
+ # must be overridden, else extension will automatically match the type
+ name='asset-invalid-type.tiff'
+ )
+
+
+class CollectionAssetsUpdateEndpointAssetFileTestCase(StacBaseTestCase):
+
+ @mock_s3_asset_file
+ def setUp(self): # pylint: disable=invalid-name
+ self.factory = Factory()
+ self.collection = self.factory.create_collection_sample(db_create=True)
+ self.asset
= self.factory.create_collection_asset_sample( + collection=self.collection.model, db_create=True + ) + self.client = Client() + client_login(self.client) + self.maxDiff = None # pylint: disable=invalid-name + + def test_asset_endpoint_patch_put_href(self): + collection_name = self.collection['name'] + asset_name = self.asset['name'] + asset_sample = self.asset.copy() + + put_payload = asset_sample.get_json('put') + put_payload['href'] = 'https://testserver/non-existing-asset' + patch_payload = {'href': 'https://testserver/non-existing-asset'} + + path = f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}' + response = self.client.patch(path, data=patch_payload, content_type="application/json") + self.assertStatusCode(400, response) + description = response.json()['description'] + self.assertIn('href', description, msg=f'Unexpected field error {description}') + self.assertEqual( + "Found read-only property in payload", + description['href'][0], + msg="Unexpected error message" + ) + + response = self.client.put(path, data=put_payload, content_type="application/json") + self.assertStatusCode(400, response) + description = response.json()['description'] + self.assertIn('href', description, msg=f'Unexpected field error {description}') + self.assertEqual( + "Found read-only property in payload", + description['href'][0], + msg="Unexpected error message" + ) + + +class CollectionAssetsUpdateEndpointTestCase(StacBaseTestCase): + + @mock_s3_asset_file + def setUp(self): # pylint: disable=invalid-name + self.factory = Factory() + self.collection = self.factory.create_collection_sample(db_create=True) + self.asset = self.factory.create_collection_asset_sample( + collection=self.collection.model, db_create=True + ) + self.client = Client() + client_login(self.client) + self.maxDiff = None # pylint: disable=invalid-name + + def test_asset_endpoint_put(self): + collection_name = self.collection['name'] + asset_name = self.asset['name'] + changed_asset = self.factory.create_collection_asset_sample( + collection=self.collection.model, + name=asset_name, + sample='asset-1-updated', + media_type=self.asset['media_type'], + checksum_multihash=self.asset['checksum_multihash'], + create_asset_file=False + ) + + path = f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}' + response = self.client.put( + path, data=changed_asset.get_json('put'), content_type="application/json" + ) + json_data = response.json() + self.assertStatusCode(200, response) + self.check_stac_collection_asset(changed_asset.json, json_data, collection_name) + + # Check the data by reading it back + response = self.client.get(path) + json_data = response.json() + self.assertStatusCode(200, response) + self.check_stac_collection_asset(changed_asset.json, json_data, collection_name) + + def test_asset_endpoint_put_extra_payload(self): + collection_name = self.collection['name'] + asset_name = self.asset['name'] + changed_asset = self.factory.create_collection_asset_sample( + collection=self.collection.model, + name=asset_name, + sample='asset-1-updated', + media_type=self.asset['media_type'], + checksum_multihash=self.asset['checksum_multihash'], + extra_attribute='not allowed', + create_asset_file=False + ) + + path = f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}' + response = self.client.put( + path, data=changed_asset.get_json('put'), content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual({'extra_attribute': ['Unexpected property in payload']}, + 
response.json()['description'], + msg='Unexpected error message') + + def test_asset_endpoint_put_read_only_in_payload(self): + collection_name = self.collection['name'] + asset_name = self.asset['name'] + changed_asset = self.factory.create_collection_asset_sample( + collection=self.collection.model, + name=asset_name, + sample='asset-1-updated', + media_type=self.asset['media_type'], + created=utc_aware(datetime.utcnow()), + create_asset_file=False, + checksum_multihash=self.asset['checksum_multihash'], + ) + + path = f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}' + response = self.client.put( + path, + data=changed_asset.get_json('put', keep_read_only=True), + content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual({ + 'created': ['Found read-only property in payload'], + 'file:checksum': ['Found read-only property in payload'] + }, + response.json()['description'], + msg='Unexpected error message') + + def test_asset_endpoint_put_rename_asset(self): + # rename should not be allowed + collection_name = self.collection['name'] + asset_name = self.asset['name'] + new_asset_name = "new-asset-name.txt" + changed_asset = self.factory.create_collection_asset_sample( + collection=self.collection.model, + name=new_asset_name, + sample='asset-1-updated', + checksum_multihash=self.asset['checksum_multihash'] + ) + + path = f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}' + response = self.client.put( + path, data=changed_asset.get_json('put'), content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual({'id': 'Renaming is not allowed'}, + response.json()['description'], + msg='Unexpected error message') + + # Check the data by reading it back + response = self.client.get( + f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}' + ) + json_data = response.json() + self.assertStatusCode(200, response) + + self.assertEqual(asset_name, json_data['id']) + + # Check the data that no new entry exist + response = self.client.get( + f'/{STAC_BASE_V}/collections/{collection_name}/assets/{new_asset_name}' + ) + + # 404 - not found + self.assertStatusCode(404, response) + + def test_asset_endpoint_patch_rename_asset(self): + # rename should not be allowed + collection_name = self.collection['name'] + asset_name = self.asset['name'] + new_asset_name = "new-asset-name.txt" + changed_asset = self.factory.create_collection_asset_sample( + collection=self.collection.model, name=new_asset_name, sample='asset-1-updated' + ) + + path = f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}' + response = self.client.patch( + path, data=changed_asset.get_json('patch'), content_type="application/json" + ) + self.assertStatusCode(400, response) + self.assertEqual({'id': 'Renaming is not allowed'}, + response.json()['description'], + msg='Unexpected error message') + + # Check the data by reading it back + response = self.client.get( + f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}' + ) + json_data = response.json() + self.assertStatusCode(200, response) + + self.assertEqual(asset_name, json_data['id']) + + # Check the data that no new entry exist + response = self.client.get( + f'/{STAC_BASE_V}/collections/{collection_name}/assets/{new_asset_name}' + ) + + # 404 - not found + self.assertStatusCode(404, response) + + def test_asset_endpoint_patch_extra_payload(self): + collection_name = self.collection['name'] + asset_name = self.asset['name'] + changed_asset = 
+            collection=self.collection.model,
+            name=asset_name,
+            sample='asset-1-updated',
+            media_type=self.asset['media_type'],
+            extra_payload='invalid'
+        )
+
+        path = f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}'
+        response = self.client.patch(
+            path, data=changed_asset.get_json('patch'), content_type="application/json"
+        )
+        self.assertStatusCode(400, response)
+        self.assertEqual({'extra_payload': ['Unexpected property in payload']},
+                         response.json()['description'],
+                         msg='Unexpected error message')
+
+    def test_asset_endpoint_patch_read_only_in_payload(self):
+        collection_name = self.collection['name']
+        asset_name = self.asset['name']
+        changed_asset = self.factory.create_collection_asset_sample(
+            collection=self.collection.model,
+            name=asset_name,
+            sample='asset-1-updated',
+            media_type=self.asset['media_type'],
+            created=utc_aware(datetime.utcnow())
+        )
+
+        path = f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}'
+        response = self.client.patch(
+            path,
+            data=changed_asset.get_json('patch', keep_read_only=True),
+            content_type="application/json"
+        )
+        self.assertStatusCode(400, response)
+        self.assertEqual({'created': ['Found read-only property in payload']},
+                         response.json()['description'],
+                         msg='Unexpected error message')
+
+    def test_asset_atomic_upsert_create_500(self):
+        sample = self.factory.create_collection_asset_sample(
+            self.collection.model, create_asset_file=True
+        )
+
+        # the dataset to update does not exist yet
+        with self.settings(DEBUG_PROPAGATE_API_EXCEPTIONS=True), disableLogger('stac_api.apps'):
+            response = self.client.put(
+                reverse(
+                    'test-collection-asset-detail-http-500',
+                    args=[self.collection['name'], sample['name']]
+                ),
+                data=sample.get_json('put'),
+                content_type='application/json'
+            )
+        self.assertStatusCode(500, response)
+        self.assertEqual(response.json()['description'], "AttributeError('test exception')")
+
+        # Make sure that the resource has not been created
+        response = self.client.get(
+            reverse_version(
+                'collection-asset-detail', args=[self.collection['name'], sample['name']]
+            )
+        )
+        self.assertStatusCode(404, response)
+
+    def test_asset_atomic_upsert_update_500(self):
+        sample = self.factory.create_collection_asset_sample(
+            self.collection.model, name=self.asset['name'], create_asset_file=True
+        )
+
+        # Make sure the sample is different from the actual data
+        self.assertNotEqual(sample.attributes, self.asset.attributes)
+
+        # the dataset to update already exists
+        with self.settings(DEBUG_PROPAGATE_API_EXCEPTIONS=True), disableLogger('stac_api.apps'):
+            # Because we explicitly test a crash here, we don't want to print a CRITICAL log on
+            # the console, therefore disable it.
+            response = self.client.put(
+                reverse(
+                    'test-collection-asset-detail-http-500',
+                    args=[self.collection['name'], sample['name']]
+                ),
+                data=sample.get_json('put'),
+                content_type='application/json'
+            )
+        self.assertStatusCode(500, response)
+        self.assertEqual(response.json()['description'], "AttributeError('test exception')")
+
+        # Make sure that the resource has not been updated
+        response = self.client.get(
+            reverse_version(
+                'collection-asset-detail', args=[self.collection['name'], sample['name']]
+            )
+        )
+        self.assertStatusCode(200, response)
+        self.check_stac_collection_asset(
+            self.asset.json, response.json(), self.collection['name'], ignore=['item']
+        )
+
+
+class CollectionAssetRaceConditionTest(StacBaseTransactionTestCase):
+
+    def setUp(self):
+        self.username = 'user'
+        self.password = 'dummy-password'
+        get_user_model().objects.create_superuser(self.username, password=self.password)
+        self.factory = Factory()
+        self.collection_sample = self.factory.create_collection_sample(
+            sample='collection-2', db_create=True
+        )
+
+    def test_asset_upsert_race_condition(self):
+        workers = 5
+        status_201 = 0
+        asset_sample = self.factory.create_collection_asset_sample(
+            self.collection_sample.model,
+            sample='asset-no-checksum',
+        )
+
+        def asset_atomic_upsert_test(worker):
+            # This method runs on a separate thread, therefore it requires creating a new client
+            # and logging it in for each call.
+            client = Client()
+            client.login(username=self.username, password=self.password)
+            return client.put(
+                reverse_version(
+                    'collection-asset-detail',
+                    args=[self.collection_sample['name'], asset_sample['name']]
+                ),
+                data=asset_sample.get_json('put'),
+                content_type='application/json'
+            )
+
+        # We call PUT on the asset several times in parallel with the same data to make sure
+        # that we don't have any race condition.
+        responses, errors = self.run_parallel(workers, asset_atomic_upsert_test)
+
+        for worker, response in responses:
+            if response.status_code == 201:
+                status_201 += 1
+            self.assertIn(
+                response.status_code, [200, 201],
+                msg=f'Unexpected response status code {response.status_code} for worker {worker}'
+            )
+            self.check_stac_collection_asset(
+                asset_sample.json, response.json(), self.collection_sample['name'], ignore=['item']
+            )
+        self.assertEqual(status_201, 1, msg="Expected exactly one upsert to create the asset")
+
+
+class CollectionAssetsDeleteEndpointTestCase(StacBaseTestCase, S3TestMixin):
+
+    @mock_s3_asset_file
+    def setUp(self):  # pylint: disable=invalid-name
+        self.factory = Factory()
+        self.collection = self.factory.create_collection_sample().model
+        self.asset = self.factory.create_collection_asset_sample(collection=self.collection).model
+        self.client = Client()
+        client_login(self.client)
+        self.maxDiff = None  # pylint: disable=invalid-name
+
+    def test_asset_endpoint_delete_asset(self):
+        collection_name = self.collection.name
+        asset_name = self.asset.name
+        path = f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}'
+        s3_path = get_collection_asset_path(self.collection, asset_name)
+        self.assertS3ObjectExists(s3_path)
+        response = self.client.delete(path)
+        self.assertStatusCode(200, response)
+
+        # Check that it has really been deleted
+        self.assertS3ObjectNotExists(s3_path)
+        response = self.client.get(path)
+        self.assertStatusCode(404, response)
+
+        # Check that it can no longer be found in the DB
+        self.assertFalse(
+            CollectionAsset.objects.filter(name=self.asset.name).exists(),
+            msg="Deleted asset still found in DB"
+        )
+
+    def test_asset_endpoint_delete_asset_invalid_name(self):
+        collection_name = self.collection.name
+        path = f"/{STAC_BASE_V}/collections/{collection_name}/assets/non-existent-asset"
+        response = self.client.delete(path)
+        self.assertStatusCode(404, response)
+
+
+class CollectionAssetsEndpointUnauthorizedTestCase(StacBaseTestCase):
+
+    @mock_s3_asset_file
+    def setUp(self):  # pylint: disable=invalid-name
+        self.factory = Factory()
+        self.collection = self.factory.create_collection_sample().model
+        self.asset = self.factory.create_collection_asset_sample(collection=self.collection).model
+        self.client = Client()
+
+    def test_unauthorized_asset_post_put_patch_delete(self):
+        collection_name = self.collection.name
+        asset_name = self.asset.name
+
+        new_asset = self.factory.create_collection_asset_sample(collection=self.collection).json
+        updated_asset = self.factory.create_collection_asset_sample(
+            collection=self.collection, name=asset_name, sample='asset-1-updated'
+        ).get_json('post')
+
+        # make sure POST fails for anonymous user:
+        path = f'/{STAC_BASE_V}/collections/{collection_name}/assets'
+        response = self.client.post(path, data=new_asset, content_type="application/json")
+        self.assertStatusCode(401, response, msg="Unauthorized post was permitted.")
+
+        # make sure PUT fails for anonymous user:
+
+        path = f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}'
+        response = self.client.put(path, data=updated_asset, content_type="application/json")
+        self.assertStatusCode(401, response, msg="Unauthorized put was permitted.")
+
+        # make sure PATCH fails for anonymous user:
+        path = f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}'
+        response = self.client.patch(path, data=updated_asset, content_type="application/json")
+        self.assertStatusCode(401, response, msg="Unauthorized patch was permitted.")
+
+        # make sure DELETE fails for anonymous user:
+        path = f'/{STAC_BASE_V}/collections/{collection_name}/assets/{asset_name}'
+        response = self.client.delete(path)
+        self.assertStatusCode(401, response, msg="Unauthorized del was permitted.")
diff --git a/app/tests/tests_10/test_serializer.py b/app/tests/tests_10/test_serializer.py
index e13096ae..79952a5c 100644
--- a/app/tests/tests_10/test_serializer.py
+++ b/app/tests/tests_10/test_serializer.py
@@ -13,9 +13,9 @@
 from rest_framework.test import APIRequestFactory
 
 from stac_api.models import get_asset_path
-from stac_api.serializers import AssetSerializer
-from stac_api.serializers import CollectionSerializer
-from stac_api.serializers import ItemSerializer
+from stac_api.serializers.collection import CollectionSerializer
+from stac_api.serializers.item import AssetSerializer
+from stac_api.serializers.item import ItemSerializer
 from stac_api.utils import get_link
 from stac_api.utils import isoformat
 from stac_api.utils import utc_aware
diff --git a/app/tests/tests_10/test_serializer_asset_upload.py b/app/tests/tests_10/test_serializer_asset_upload.py
index b389b256..c5d359a2 100644
--- a/app/tests/tests_10/test_serializer_asset_upload.py
+++ b/app/tests/tests_10/test_serializer_asset_upload.py
@@ -7,7 +7,7 @@
 from rest_framework.exceptions import ValidationError
 
 from stac_api.models import AssetUpload
-from stac_api.serializers import AssetUploadSerializer
+from stac_api.serializers.upload import AssetUploadSerializer
 from stac_api.utils import get_sha256_multihash
 
 from tests.tests_10.base_test import StacBaseTestCase
diff --git a/scripts/fill_local_db.py b/scripts/fill_local_db.py
index ca8d6871..6db19f8e 100644
--- a/scripts/fill_local_db.py
+++ b/scripts/fill_local_db.py
@@ -20,10 +20,10 @@
 from rest_framework.renderers import JSONRenderer
 from rest_framework.parsers import JSONParser
 from stac_api.models import *
-from stac_api.serializers import CollectionSerializer
-from stac_api.serializers import CollectionSerializer
-from stac_api.serializers import LinkSerializer
-from stac_api.serializers import ProviderSerializer
+from stac_api.serializers.general import CollectionSerializer
+from stac_api.serializers.general import CollectionSerializer
+from stac_api.serializers.general import LinkSerializer
+from stac_api.serializers.general import ProviderSerializer
 
 # create link instances for testing