diff --git a/aws-storage-s3/src/androidTest/java/com/amplifyframework/storage/s3/transfer/TransferDBTest.kt b/aws-storage-s3/src/androidTest/java/com/amplifyframework/storage/s3/transfer/TransferDBTest.kt index 199428831b..1922c896a1 100644 --- a/aws-storage-s3/src/androidTest/java/com/amplifyframework/storage/s3/transfer/TransferDBTest.kt +++ b/aws-storage-s3/src/androidTest/java/com/amplifyframework/storage/s3/transfer/TransferDBTest.kt @@ -31,6 +31,7 @@ import org.junit.Test open class TransferDBTest { private val bucketName = "bucket_name" + private val region = "us-east-1" private val fileKey = "file_key" private lateinit var transferDB: TransferDB private lateinit var tempFile: File @@ -55,6 +56,7 @@ open class TransferDBTest { transferId, TransferType.UPLOAD, bucketName, + region, fileKey, tempFile, null, @@ -67,6 +69,7 @@ open class TransferDBTest { Assert.assertEquals(tempFile, File(this.file)) Assert.assertEquals(fileKey, this.key) Assert.assertEquals(bucketName, this.bucketName) + Assert.assertEquals(region, this.region) } ?: Assert.fail("InsertedRecord is null") } @@ -76,6 +79,7 @@ open class TransferDBTest { val uri = transferDB.insertMultipartUploadRecord( uploadID, bucketName, + region, fileKey, tempFile, 1L, @@ -91,6 +95,7 @@ open class TransferDBTest { Assert.assertEquals(fileKey, this.key) Assert.assertEquals(bucketName, this.bucketName) Assert.assertEquals(uploadID, this.multipartId) + Assert.assertEquals(region, this.region) } ?: Assert.fail("InsertedRecord is null") } @@ -104,6 +109,7 @@ open class TransferDBTest { contentValues[0] = transferDB.generateContentValuesForMultiPartUpload( key, bucketName, + region, key, tempFile, 0L, @@ -137,6 +143,7 @@ open class TransferDBTest { contentValues[0] = transferDB.generateContentValuesForMultiPartUpload( key, bucketName, + region, key, tempFile, 0L, @@ -151,6 +158,7 @@ open class TransferDBTest { contentValues[1] = transferDB.generateContentValuesForMultiPartUpload( key, bucketName, + region, key, tempFile, 0L, @@ -165,6 +173,7 @@ open class TransferDBTest { contentValues[2] = transferDB.generateContentValuesForMultiPartUpload( key, bucketName, + region, key, tempFile, 0L, diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/AWSS3StoragePlugin.java b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/AWSS3StoragePlugin.java index 01fde9ba41..46316bef0c 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/AWSS3StoragePlugin.java +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/AWSS3StoragePlugin.java @@ -90,7 +90,9 @@ import com.amplifyframework.storage.s3.request.AWSS3StorageRemoveRequest; import com.amplifyframework.storage.s3.request.AWSS3StorageUploadRequest; import com.amplifyframework.storage.s3.service.AWSS3StorageService; -import com.amplifyframework.storage.s3.service.StorageService; +import com.amplifyframework.storage.s3.service.AWSS3StorageServiceContainer; +import com.amplifyframework.storage.s3.transfer.S3StorageTransferClientProvider; +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider; import com.amplifyframework.storage.s3.transfer.TransferObserver; import com.amplifyframework.storage.s3.transfer.TransferRecord; import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater; @@ -101,9 +103,7 @@ import java.io.File; import java.io.InputStream; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.Objects; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -126,20 +126,33 @@ public final class AWSS3StoragePlugin extends StoragePlugin { private static final int DEFAULT_URL_EXPIRATION_DAYS = 7; - private final StorageService.Factory storageServiceFactory; + private final AWSS3StorageService.Factory storageServiceFactory; private final ExecutorService executorService; - private final AuthCredentialsProvider authCredentialsProvider; + private AuthCredentialsProvider authCredentialsProvider; private final AWSS3StoragePluginConfiguration awsS3StoragePluginConfiguration; private AWSS3StorageService defaultStorageService; @SuppressWarnings("deprecation") private StorageAccessLevel defaultAccessLevel; private int defaultUrlExpiration; - private Map awsS3StorageServicesByBucketName = new HashMap<>(); - private Context context; + private AWSS3StorageServiceContainer awss3StorageServiceContainer; @SuppressLint("UnsafeOptInUsageError") private List configuredBuckets; + @SuppressLint("UnsafeOptInUsageError") + private StorageTransferClientProvider clientProvider + = new S3StorageTransferClientProvider((region, bucketName) -> { + if (region != null && bucketName != null) { + StorageBucket bucket = StorageBucket.fromBucketInfo(new BucketInfo(bucketName, region)); + return awss3StorageServiceContainer.get((ResolvedStorageBucket) bucket).getClient(); + } + + if (region != null) { + return S3StorageTransferClientProvider.getS3Client(region, authCredentialsProvider); + } + return defaultStorageService.getClient(); + }); + /** * Constructs the AWS S3 Storage Plugin initializing the executor service. */ @@ -162,13 +175,14 @@ public AWSS3StoragePlugin(AWSS3StoragePluginConfiguration awsS3StoragePluginConf @VisibleForTesting AWSS3StoragePlugin(AuthCredentialsProvider authCredentialsProvider) { - this((context, region, bucket) -> + this((context, region, bucket, clientProvider) -> new AWSS3StorageService( context, region, bucket, authCredentialsProvider, - AWS_S3_STORAGE_PLUGIN_KEY + AWS_S3_STORAGE_PLUGIN_KEY, + clientProvider ), authCredentialsProvider, new AWSS3StoragePluginConfiguration.Builder().build()); @@ -177,13 +191,15 @@ public AWSS3StoragePlugin(AWSS3StoragePluginConfiguration awsS3StoragePluginConf @VisibleForTesting AWSS3StoragePlugin(AuthCredentialsProvider authCredentialsProvider, AWSS3StoragePluginConfiguration awss3StoragePluginConfiguration) { - this((context, region, bucket) -> + + this((context, region, bucket, clientProvider) -> new AWSS3StorageService( context, region, bucket, authCredentialsProvider, - AWS_S3_STORAGE_PLUGIN_KEY + AWS_S3_STORAGE_PLUGIN_KEY, + clientProvider ), authCredentialsProvider, awss3StoragePluginConfiguration); @@ -191,7 +207,7 @@ public AWSS3StoragePlugin(AWSS3StoragePluginConfiguration awsS3StoragePluginConf @VisibleForTesting AWSS3StoragePlugin( - StorageService.Factory storageServiceFactory, + AWSS3StorageService.Factory storageServiceFactory, AuthCredentialsProvider authCredentialsProvider, AWSS3StoragePluginConfiguration awss3StoragePluginConfiguration ) { @@ -281,13 +297,15 @@ private void configure( @NonNull ResolvedStorageBucket bucket ) throws StorageException { try { - this.context = context; - this.defaultStorageService = (AWSS3StorageService) storageServiceFactory.create( + this.defaultStorageService = storageServiceFactory.create( context, region, - bucket.getBucketInfo().getName()); - this.awsS3StorageServicesByBucketName.clear(); - this.awsS3StorageServicesByBucketName.put(bucket.getBucketInfo().getName(), this.defaultStorageService); + bucket.getBucketInfo().getName(), + clientProvider); + this.awss3StorageServiceContainer = new AWSS3StorageServiceContainer( + context, storageServiceFactory, + (S3StorageTransferClientProvider) clientProvider); + this.awss3StorageServiceContainer.put(bucket.getBucketInfo().getName(), this.defaultStorageService); } catch (RuntimeException exception) { throw new StorageException( "Failed to create storage service.", @@ -935,7 +953,8 @@ public StorageRemoveOperation remove( return operation; } - + + @SuppressLint("UnsafeOptInUsageError") @Override @SuppressWarnings("deprecation") public void getTransfer( @@ -951,18 +970,23 @@ public void getTransfer( transferRecord.getId(), defaultStorageService.getTransferManager().getTransferStatusUpdater(), transferRecord.getBucketName(), + transferRecord.getRegion(), transferRecord.getKey(), transferRecord.getFile(), null, transferRecord.getState() != null ? transferRecord.getState() : TransferState.UNKNOWN); TransferType transferType = transferRecord.getType(); + + AWSS3StorageService storageService + = getAwss3StorageServiceFromTransferRecord(onError, transferRecord); + switch (Objects.requireNonNull(transferType)) { case UPLOAD: if (transferRecord.getFile().startsWith(TransferStatusUpdater.TEMP_FILE_PREFIX)) { AWSS3StorageUploadInputStreamOperation operation = new AWSS3StorageUploadInputStreamOperation( transferId, - defaultStorageService, + storageService, executorService, authCredentialsProvider, awsS3StoragePluginConfiguration, @@ -973,7 +997,7 @@ public void getTransfer( AWSS3StorageUploadFileOperation operation = new AWSS3StorageUploadFileOperation( transferId, - defaultStorageService, + storageService, executorService, authCredentialsProvider, awsS3StoragePluginConfiguration, @@ -987,7 +1011,7 @@ public void getTransfer( downloadFileOperation = new AWSS3StorageDownloadFileOperation( transferId, new File(transferRecord.getFile()), - defaultStorageService, + storageService, executorService, authCredentialsProvider, awsS3StoragePluginConfiguration, @@ -1009,6 +1033,25 @@ public void getTransfer( }); } + private AWSS3StorageService getAwss3StorageServiceFromTransferRecord( + @NonNull Consumer onError, + TransferRecord transferRecord + ) { + AWSS3StorageService storageService = defaultStorageService; + if (transferRecord.getRegion() != null && transferRecord.getBucketName() != null) { + try { + BucketInfo bucketInfo = new BucketInfo( + transferRecord.getBucketName(), + transferRecord.getRegion()); + StorageBucket bucket = StorageBucket.fromBucketInfo(bucketInfo); + storageService = getStorageService(bucket); + } catch (StorageException exception) { + onError.accept(exception); + } + } + return storageService; + } + @NonNull @SuppressWarnings("deprecation") @Override @@ -1105,55 +1148,27 @@ AWSS3StorageService getStorageService(@Nullable StorageBucket bucket) throws Sto } if (bucket instanceof OutputsStorageBucket) { - AWSS3StorageService service = getAWSS3StorageService((OutputsStorageBucket) bucket); - if (service == null) { - throw new StorageException( - "Unable to find bucket from name in Amplify Outputs.", - new InvalidStorageBucketException(), - "Ensure the bucket name used is available in Amplify Outputs."); - } else { - return service; - } - } - - if (bucket instanceof ResolvedStorageBucket) { - return getAWSS3StorageService((ResolvedStorageBucket) bucket); - } - - return defaultStorageService; - } - - @SuppressLint("UnsafeOptInUsageError") - private AWSS3StorageService getAWSS3StorageService(OutputsStorageBucket outputsStorageBucket) { - if (configuredBuckets != null && !configuredBuckets.isEmpty()) { - String name = outputsStorageBucket.getName(); - for (AmplifyOutputsData.StorageBucket configuredBucket : configuredBuckets) { - if (configuredBucket.getName().equals(name)) { - String bucketName = configuredBucket.getBucketName(); - AWSS3StorageService service = awsS3StorageServicesByBucketName.get(bucketName); - if (service == null) { + if (configuredBuckets != null && !configuredBuckets.isEmpty()) { + String name = ((OutputsStorageBucket) bucket).getName(); + for (AmplifyOutputsData.StorageBucket configuredBucket : configuredBuckets) { + if (configuredBucket.getName().equals(name)) { + String bucketName = configuredBucket.getBucketName(); String region = configuredBucket.getAwsRegion(); - service = (AWSS3StorageService) storageServiceFactory.create(context, region, bucketName); - awsS3StorageServicesByBucketName.put(bucketName, service); + return awss3StorageServiceContainer.get(bucketName, region); } - - return service; } } + throw new StorageException( + "Unable to find bucket from name in Amplify Outputs.", + new InvalidStorageBucketException(), + "Ensure the bucket name used is available in Amplify Outputs."); } - return null; - } - @SuppressLint("UnsafeOptInUsageError") - private AWSS3StorageService getAWSS3StorageService(ResolvedStorageBucket resolvedStorageBucket) { - String bucketName = resolvedStorageBucket.getBucketInfo().getName(); - AWSS3StorageService service = awsS3StorageServicesByBucketName.get(bucketName); - if (service == null) { - String region = resolvedStorageBucket.getBucketInfo().getRegion(); - service = (AWSS3StorageService) storageServiceFactory.create(context, region, bucketName); - awsS3StorageServicesByBucketName.put(bucketName, service); + if (bucket instanceof ResolvedStorageBucket) { + return awss3StorageServiceContainer.get((ResolvedStorageBucket) bucket); } - return service; + + return defaultStorageService; } /** diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/TransferOperations.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/TransferOperations.kt index 47cd602c41..99d15ea9d8 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/TransferOperations.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/TransferOperations.kt @@ -70,6 +70,7 @@ internal object TransferOperations { transferRecord.id, transferStatusUpdater, transferRecord.bucketName, + transferRecord.region, transferRecord.key, transferRecord.file, listener diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/service/AWSS3StorageService.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/service/AWSS3StorageService.kt index 7d9eef745d..eb7e573c2f 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/service/AWSS3StorageService.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/service/AWSS3StorageService.kt @@ -32,6 +32,8 @@ import com.amplifyframework.storage.StorageItem import com.amplifyframework.storage.options.SubpathStrategy import com.amplifyframework.storage.options.SubpathStrategy.Exclude import com.amplifyframework.storage.result.StorageListResult +import com.amplifyframework.storage.s3.transfer.S3StorageTransferClientProvider +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider import com.amplifyframework.storage.s3.transfer.TransferManager import com.amplifyframework.storage.s3.transfer.TransferObserver import com.amplifyframework.storage.s3.transfer.TransferRecord @@ -46,7 +48,6 @@ import java.util.Date import kotlin.time.Duration.Companion.seconds import kotlin.time.ExperimentalTime import kotlinx.coroutines.runBlocking - /** * A representation of an S3 backend service endpoint. */ @@ -55,16 +56,14 @@ internal class AWSS3StorageService( private val awsRegion: String, private val s3BucketName: String, private val authCredentialsProvider: AuthCredentialsProvider, - private val awsS3StoragePluginKey: String + private val awsS3StoragePluginKey: String, + private val clientProvider: StorageTransferClientProvider ) : StorageService { - private var s3Client: S3Client = S3Client { - region = awsRegion - credentialsProvider = authCredentialsProvider - } + private var s3Client: S3Client = S3StorageTransferClientProvider.getS3Client(awsRegion, authCredentialsProvider) val transferManager: TransferManager = - TransferManager(context, s3Client, awsS3StoragePluginKey) + TransferManager(context, clientProvider, awsS3StoragePluginKey) /** * Generate pre-signed URL for an object. @@ -130,6 +129,7 @@ internal class AWSS3StorageService( return transferManager.download( transferId, s3BucketName, + awsRegion, serviceKey, file, useAccelerateEndpoint = useAccelerateEndpoint @@ -153,6 +153,7 @@ internal class AWSS3StorageService( return transferManager.upload( transferId, s3BucketName, + awsRegion, serviceKey, file, metadata, @@ -175,7 +176,7 @@ internal class AWSS3StorageService( metadata: ObjectMetadata, useAccelerateEndpoint: Boolean ): TransferObserver { - val uploadOptions = UploadOptions(s3BucketName, metadata) + val uploadOptions = UploadOptions(s3BucketName, awsRegion, metadata) return transferManager.upload(transferId, serviceKey, inputStream, uploadOptions, useAccelerateEndpoint) } @@ -420,4 +421,21 @@ internal class AWSS3StorageService( fun getClient(): S3Client { return s3Client } + + interface Factory { + /** + * Factory interface to instantiate [StorageService] object. + * + * @param context Android context + * @param region S3 bucket region + * @param bucketName Name of the bucket where the items are stored + * @return An instantiated storage service instance + */ + fun create( + context: Context, + region: String, + bucketName: String, + clientProvider: StorageTransferClientProvider + ): AWSS3StorageService + } } diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/service/AWSS3StorageServiceContainer.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/service/AWSS3StorageServiceContainer.kt new file mode 100644 index 0000000000..f22131df2f --- /dev/null +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/service/AWSS3StorageServiceContainer.kt @@ -0,0 +1,88 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +package com.amplifyframework.storage.s3.service + +import android.content.Context +import com.amplifyframework.storage.ResolvedStorageBucket +import com.amplifyframework.storage.s3.transfer.S3StorageTransferClientProvider +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider +import java.util.concurrent.ConcurrentHashMap + +/** + * A container that stores a list of AWSS3StorageService based on the bucket name associated with the service. + * repository. + */ +internal class AWSS3StorageServiceContainer( + private val context: Context, + private val storageServiceFactory: AWSS3StorageService.Factory, + private val clientProvider: StorageTransferClientProvider, + private val awsS3StorageServicesByBucketName: ConcurrentHashMap +) { + constructor( + context: Context, + storageServiceFactory: AWSS3StorageService.Factory, + clientProvider: S3StorageTransferClientProvider + ) : this(context, storageServiceFactory, clientProvider, ConcurrentHashMap()) + + private val lock = Any() + + /** + * Stores a instance of AWSS3StorageService + * + * @param bucketName the bucket name + * @param service the AWSS3StorageService instance + */ + fun put(bucketName: String, service: AWSS3StorageService) { + synchronized(lock) { + awsS3StorageServicesByBucketName.put(bucketName, service) + } + } + + /** + * Get an AWSS3StorageSErvice instance based on a ResolvedStorageBucket + * @param resolvedStorageBucket An instance of ResolvedStorageBucket with bucket info + * @return An AWSS3StorageService instance associated with the ResolvedStorageBucket + */ + fun get(resolvedStorageBucket: ResolvedStorageBucket): AWSS3StorageService { + synchronized(lock) { + val bucketName: String = resolvedStorageBucket.bucketInfo.name + var service = awsS3StorageServicesByBucketName.get(bucketName) + if (service == null) { + val region: String = resolvedStorageBucket.bucketInfo.region + service = storageServiceFactory.create(context, region, bucketName, clientProvider) + awsS3StorageServicesByBucketName[bucketName] = service + } + return service + } + } + + /** + * Get an AWSS3StorageSErvice instance based on a bucket name and region + * @param bucketName the bucket name associated with the AWSS3StorageService + * @param bucketName the region to associate with a new AWSS3StorageService instance if one doesn't exist + * @return An AWSS3StorageService instance associated with the ResolvedStorageBucket + */ + fun get(bucketName: String, region: String): AWSS3StorageService { + synchronized(lock) { + var service = awsS3StorageServicesByBucketName[bucketName] + if (service == null) { + service = storageServiceFactory.create(context, region, bucketName, clientProvider) + awsS3StorageServicesByBucketName[bucketName] = service + } + + return service + } + } +} diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/S3StorageTransferClientProvider.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/S3StorageTransferClientProvider.kt new file mode 100644 index 0000000000..49d976f7b3 --- /dev/null +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/S3StorageTransferClientProvider.kt @@ -0,0 +1,35 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +package com.amplifyframework.storage.s3.transfer + +import aws.sdk.kotlin.services.s3.S3Client +import com.amplifyframework.auth.AuthCredentialsProvider + +internal class S3StorageTransferClientProvider( + private val createS3Client: (region: String?, bucketName: String?) -> S3Client +) : StorageTransferClientProvider { + companion object { + @JvmStatic + fun getS3Client(region: String, authCredentialsProvider: AuthCredentialsProvider): S3Client { + return S3Client { + this.region = region + this.credentialsProvider = authCredentialsProvider + } + } + } + override fun getStorageTransferClient(region: String?, bucketName: String?): S3Client { + return createS3Client(region, bucketName) + } +} diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/StorageTransferClientProvider.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/StorageTransferClientProvider.kt new file mode 100644 index 0000000000..523c0dcf3e --- /dev/null +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/StorageTransferClientProvider.kt @@ -0,0 +1,22 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +package com.amplifyframework.storage.s3.transfer + +import aws.sdk.kotlin.services.s3.S3Client +import com.amplifyframework.annotations.InternalApiWarning + +internal interface StorageTransferClientProvider { + fun getStorageTransferClient(region: String?, bucketName: String?): S3Client +} diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferDB.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferDB.kt index 6413cf7c45..9c2949654b 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferDB.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferDB.kt @@ -79,6 +79,7 @@ internal class TransferDB private constructor(context: Context) { fun insertMultipartUploadRecord( transferId: String, bucket: String, + region: String, key: String, file: File, fileOffset: Long, @@ -90,7 +91,7 @@ internal class TransferDB private constructor(context: Context) { ): Uri { val values: ContentValues = generateContentValuesForMultiPartUpload( transferId, - bucket, key, file, + bucket, region, key, file, fileOffset, partNumber, uploadId, bytesTotal, isLastPart, ObjectMetadata(), null, useAccelerateEndpoint @@ -114,6 +115,7 @@ internal class TransferDB private constructor(context: Context) { transferId: String, type: TransferType, bucket: String, + region: String, key: String, file: File?, cannedAcl: ObjectCannedAcl? = null, @@ -124,6 +126,7 @@ internal class TransferDB private constructor(context: Context) { transferId, type, bucket, + region, key, file, metadata, @@ -398,7 +401,7 @@ internal class TransferDB private constructor(context: Context) { */ fun getTransferRecordById(id: Int): TransferRecord? { var transferRecord: TransferRecord? = null - var c: Cursor? = null + var c: Cursor? try { c = queryTransferById(id) c?.use { @@ -415,7 +418,7 @@ internal class TransferDB private constructor(context: Context) { fun getTransferByTransferId(transferId: String): TransferRecord? { var transferRecord: TransferRecord? = null - var c: Cursor? = null + var c: Cursor? try { c = transferDBHelper.query(getTransferRecordIdUri(transferId)) c.use { @@ -560,6 +563,7 @@ internal class TransferDB private constructor(context: Context) { transferId: String, type: TransferType, bucket: String, + region: String, key: String, file: File, metadata: ObjectMetadata?, @@ -570,6 +574,7 @@ internal class TransferDB private constructor(context: Context) { transferId, type, bucket, + region, key, file, metadata, @@ -602,6 +607,7 @@ internal class TransferDB private constructor(context: Context) { fun generateContentValuesForMultiPartUpload( transferId: String, bucket: String?, + region: String?, key: String?, file: File, fileOffset: Long, @@ -618,6 +624,7 @@ internal class TransferDB private constructor(context: Context) { values.put(TransferTable.COLUMN_TYPE, TransferType.UPLOAD.toString()) values.put(TransferTable.COLUMN_STATE, TransferState.WAITING.toString()) values.put(TransferTable.COLUMN_BUCKET_NAME, bucket) + values.put(TransferTable.COLUMN_REGION, region) values.put(TransferTable.COLUMN_KEY, key) values.put(TransferTable.COLUMN_FILE, file.absolutePath) values.put(TransferTable.COLUMN_BYTES_CURRENT, 0L) @@ -723,6 +730,7 @@ internal class TransferDB private constructor(context: Context) { transferId: String, type: TransferType, bucket: String, + region: String, key: String, file: File?, metadata: ObjectMetadata?, @@ -734,6 +742,7 @@ internal class TransferDB private constructor(context: Context) { values.put(TransferTable.COLUMN_TYPE, type.toString()) values.put(TransferTable.COLUMN_STATE, TransferState.WAITING.toString()) values.put(TransferTable.COLUMN_BUCKET_NAME, bucket) + values.put(TransferTable.COLUMN_REGION, region) values.put(TransferTable.COLUMN_KEY, key) values.put(TransferTable.COLUMN_FILE, file?.absolutePath) values.put(TransferTable.COLUMN_BYTES_CURRENT, 0L) diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferDBHelper.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferDBHelper.kt index 211ea745e5..4403d51df4 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferDBHelper.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferDBHelper.kt @@ -48,7 +48,7 @@ internal class TransferDBHelper(private val context: Context) : SQLiteOpenHelper // This represents the latest database version. // Update this when the database is being upgraded. - private const val DATABASE_VERSION = 9 + private const val DATABASE_VERSION = 10 private const val BASE_PATH = "transfers" private const val TRANSFERS = 10 private const val TRANSFER_ID = 20 diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferManager.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferManager.kt index 8a4bd9cb30..119aeeafb9 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferManager.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferManager.kt @@ -45,7 +45,7 @@ import kotlin.math.min */ internal class TransferManager( context: Context, - s3: S3Client, + clientProvider: StorageTransferClientProvider, private val pluginKey: String, private val workManager: WorkManager = WorkManager.getInstance(context) ) { @@ -71,7 +71,7 @@ internal class TransferManager( init { RouterWorker.workerFactories[pluginKey] = TransferWorkerFactory( transferDB, - s3, + clientProvider, transferStatusUpdater ) } @@ -93,6 +93,7 @@ internal class TransferManager( fun upload( transferId: String, bucket: String, + region: String, key: String, file: File, metadata: ObjectMetadata, @@ -101,12 +102,13 @@ internal class TransferManager( useAccelerateEndpoint: Boolean = false ): TransferObserver { val transferRecordId = if (shouldUploadInMultipart(file)) { - createMultipartUploadRecords(transferId, bucket, key, file, metadata, cannedAcl, useAccelerateEndpoint) + createMultipartUploadRecords(transferId, bucket, region, key, file, metadata, cannedAcl, useAccelerateEndpoint) } else { val uri = transferDB.insertSingleTransferRecord( transferId, TransferType.UPLOAD, bucket, + region, key, file, cannedAcl, @@ -147,6 +149,7 @@ internal class TransferManager( return upload( transferId, options.bucket, + options.region, key, file, options.objectMetadata, @@ -160,6 +163,7 @@ internal class TransferManager( fun download( transferId: String, bucket: String, + region: String, key: String, file: File, listener: TransferListener? = null, @@ -172,6 +176,7 @@ internal class TransferManager( transferId, TransferType.DOWNLOAD, bucket, + region, key, file, useAccelerateEndpoint = useAccelerateEndpoint @@ -246,6 +251,7 @@ internal class TransferManager( private fun createMultipartUploadRecords( transferId: String, bucket: String, + region: String, key: String, file: File, metadata: ObjectMetadata, @@ -263,6 +269,7 @@ internal class TransferManager( contentValues[0] = transferDB.generateContentValuesForMultiPartUpload( transferId, bucket, + region, key, file, fileOffset, @@ -279,6 +286,7 @@ internal class TransferManager( contentValues[partNum] = transferDB.generateContentValuesForMultiPartUpload( UUID.randomUUID().toString(), bucket, + region, key, file, fileOffset, diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferObserver.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferObserver.kt index 021e06a5b8..b6147f5cfe 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferObserver.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferObserver.kt @@ -24,6 +24,7 @@ internal data class TransferObserver @JvmOverloads constructor( val id: Int, private val transferStatusUpdater: TransferStatusUpdater, val bucket: String? = null, + val region: String? = null, val key: String? = null, val filePath: String? = null, private val listener: TransferListener? = null, diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferRecord.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferRecord.kt index c32a85c7ce..73d219b97d 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferRecord.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferRecord.kt @@ -36,6 +36,7 @@ internal data class TransferRecord( var type: TransferType? = null, var state: TransferState? = null, var bucketName: String? = null, + var region: String? = null, var key: String? = null, var versionId: String? = null, var file: String = "", @@ -80,6 +81,8 @@ internal data class TransferRecord( ) this.bucketName = c.getString(c.getColumnIndexOrThrow(TransferTable.COLUMN_BUCKET_NAME)) + this.region = + c.getString(c.getColumnIndexOrThrow(TransferTable.COLUMN_REGION)) this.key = c.getString(c.getColumnIndexOrThrow(TransferTable.COLUMN_KEY)) this.versionId = c.getString(c.getColumnIndexOrThrow(TransferTable.COLUMN_VERSION_ID)) diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferTable.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferTable.kt index 85b515062d..d7417c4937 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferTable.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferTable.kt @@ -134,6 +134,8 @@ internal class TransferTable { const val COLUMN_USE_ACCELERATE_ENDPOINT = "useAccelerateEndpoint" + const val COLUMN_REGION = "region" + private const val TABLE_VERSION_2 = 2 private const val TABLE_VERSION_3 = 3 private const val TABLE_VERSION_4 = 4 @@ -142,8 +144,13 @@ internal class TransferTable { private const val TABLE_VERSION_7 = 7 private const val TABLE_VERSION_8 = 8 private const val TABLE_VERSION_9 = 9 + private const val TABLE_VERSION_10 = 10 - // Database creation SQL statement + // **** DO NOT UPDATE *** + // Database creation SQL statement for TABLE_VERSION 1 + // The current database migration implementation assumes that the original version 1 is always created + // and then incrementally upgrades from the original version 1 to latest version. + // instead of of upgrading from the last/previous version to the latest version. const val DATABASE_CREATE = "create table $TABLE_TRANSFER (" + "$COLUMN_ID integer primary key autoincrement, " + "$COLUMN_MAIN_UPLOAD_ID integer, " + @@ -219,6 +226,9 @@ internal class TransferTable { if (TABLE_VERSION_9 in (oldVersion + 1)..newVersion) { addVersion9Columns(database) } + if (TABLE_VERSION_10 in (oldVersion + 1)..newVersion) { + addVersion10Columns(database) + } database.setTransactionSuccessful() database.endTransaction() } @@ -296,5 +306,10 @@ internal class TransferTable { "DEFAULT 0;" database.execSQL(addConnectionType) } + + private fun addVersion10Columns(database: SQLiteDatabase) { + val addRegion = "ALTER TABLE $TABLE_TRANSFER ADD COLUMN $COLUMN_REGION text " + "DEFAULT null;" + database.execSQL(addRegion) + } } } diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/UploadOptions.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/UploadOptions.kt index f68453c757..635bc7eb7f 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/UploadOptions.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/UploadOptions.kt @@ -24,6 +24,7 @@ import com.amplifyframework.storage.ObjectMetadata internal data class UploadOptions @JvmOverloads constructor( val bucket: String, + val region: String, val objectMetadata: ObjectMetadata = ObjectMetadata(), val cannedAcl: ObjectCannedAcl? = null, val transferListener: TransferListener? = null diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/AbortMultiPartUploadWorker.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/AbortMultiPartUploadWorker.kt index 30aa2d20b9..3f9cb5efb9 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/AbortMultiPartUploadWorker.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/AbortMultiPartUploadWorker.kt @@ -20,6 +20,7 @@ import aws.sdk.kotlin.services.s3.S3Client import aws.sdk.kotlin.services.s3.abortMultipartUpload import aws.sdk.kotlin.services.s3.withConfig import com.amplifyframework.storage.TransferState +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider import com.amplifyframework.storage.s3.transfer.TransferDB import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater @@ -27,7 +28,7 @@ import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater * Worker to abort pending multipart upload **/ internal class AbortMultiPartUploadWorker( - private val s3: S3Client, + private val clientProvider: StorageTransferClientProvider, private val transferDB: TransferDB, private val transferStatusUpdater: TransferStatusUpdater, context: Context, @@ -35,6 +36,7 @@ internal class AbortMultiPartUploadWorker( ) : BaseTransferWorker(transferStatusUpdater, transferDB, context, workerParameters) { override suspend fun performWork(): Result { + val s3: S3Client = clientProvider.getStorageTransferClient(transferRecord.region, transferRecord.bucketName) return s3.withConfig { enableAccelerate = transferRecord.useAccelerateEndpoint == 1 }.abortMultipartUpload { diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/CompleteMultiPartUploadWorker.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/CompleteMultiPartUploadWorker.kt index 8b48014a87..d7cadd5dd2 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/CompleteMultiPartUploadWorker.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/CompleteMultiPartUploadWorker.kt @@ -20,6 +20,7 @@ import aws.sdk.kotlin.services.s3.S3Client import aws.sdk.kotlin.services.s3.completeMultipartUpload import aws.sdk.kotlin.services.s3.model.CompletedMultipartUpload import aws.sdk.kotlin.services.s3.withConfig +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider import com.amplifyframework.storage.s3.transfer.TransferDB import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater @@ -27,7 +28,7 @@ import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater * Worker to complete multipart upload **/ internal class CompleteMultiPartUploadWorker( - private val s3: S3Client, + private val clientProvider: StorageTransferClientProvider, private val transferDB: TransferDB, private val transferStatusUpdater: TransferStatusUpdater, context: Context, @@ -36,6 +37,7 @@ internal class CompleteMultiPartUploadWorker( override suspend fun performWork(): Result { val completedParts = transferDB.queryPartETagsOfUpload(transferRecord.id) + val s3: S3Client = clientProvider.getStorageTransferClient(transferRecord.region, transferRecord.bucketName) return s3.withConfig { enableAccelerate = transferRecord.useAccelerateEndpoint == 1 }.completeMultipartUpload { diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/DownloadWorker.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/DownloadWorker.kt index 87a9db5612..9af6f8d70d 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/DownloadWorker.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/DownloadWorker.kt @@ -26,6 +26,7 @@ import aws.smithy.kotlin.runtime.io.SdkSource import aws.smithy.kotlin.runtime.io.buffer import com.amplifyframework.storage.s3.transfer.DownloadProgressListener import com.amplifyframework.storage.s3.transfer.DownloadProgressListenerInterceptor +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider import com.amplifyframework.storage.s3.transfer.TransferDB import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater import java.io.BufferedOutputStream @@ -40,7 +41,7 @@ import kotlinx.coroutines.withContext * Worker to perform download file task. */ internal class DownloadWorker( - private val s3: S3Client, + private val clientProvider: StorageTransferClientProvider, private val transferDB: TransferDB, private val transferStatusUpdater: TransferStatusUpdater, context: Context, @@ -50,6 +51,7 @@ internal class DownloadWorker( private lateinit var downloadProgressListener: DownloadProgressListener private val defaultBufferSize = 8192L override suspend fun performWork(): Result { + val s3: S3Client = clientProvider.getStorageTransferClient(transferRecord.region, transferRecord.bucketName) s3.withConfig { enableAccelerate = transferRecord.useAccelerateEndpoint == 1 } diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/InitiateMultiPartUploadTransferWorker.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/InitiateMultiPartUploadTransferWorker.kt index ec7be7cb35..b3c013b4bc 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/InitiateMultiPartUploadTransferWorker.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/InitiateMultiPartUploadTransferWorker.kt @@ -21,6 +21,7 @@ import aws.sdk.kotlin.services.s3.S3Client import aws.sdk.kotlin.services.s3.createMultipartUpload import aws.sdk.kotlin.services.s3.withConfig import com.amplifyframework.storage.TransferState +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider import com.amplifyframework.storage.s3.transfer.TransferDB import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater @@ -28,7 +29,7 @@ import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater * Worker to initiate multipart upload **/ internal class InitiateMultiPartUploadTransferWorker( - private val s3: S3Client, + private val clientProvider: StorageTransferClientProvider, private val transferDB: TransferDB, private val transferStatusUpdater: TransferStatusUpdater, context: Context, @@ -36,6 +37,7 @@ internal class InitiateMultiPartUploadTransferWorker( ) : BaseTransferWorker(transferStatusUpdater, transferDB, context, workerParameters) { override suspend fun performWork(): Result { + val s3: S3Client = clientProvider.getStorageTransferClient(transferRecord.region, transferRecord.bucketName) transferStatusUpdater.updateTransferState(transferRecord.id, TransferState.IN_PROGRESS) val putObjectRequest = createPutObjectRequest(transferRecord, null) return s3.withConfig { diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/PartUploadTransferWorker.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/PartUploadTransferWorker.kt index d145786b6d..b7a0f6760d 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/PartUploadTransferWorker.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/PartUploadTransferWorker.kt @@ -22,6 +22,7 @@ import aws.sdk.kotlin.services.s3.withConfig import aws.smithy.kotlin.runtime.content.asByteStream import com.amplifyframework.storage.TransferState import com.amplifyframework.storage.s3.transfer.PartUploadProgressListener +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider import com.amplifyframework.storage.s3.transfer.TransferDB import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater import com.amplifyframework.storage.s3.transfer.UploadProgressListenerInterceptor @@ -33,7 +34,7 @@ import kotlinx.coroutines.isActive * Worker to upload a part for multipart upload **/ internal class PartUploadTransferWorker( - private val s3: S3Client, + private val clientProvider: StorageTransferClientProvider, private val transferDB: TransferDB, private val transferStatusUpdater: TransferStatusUpdater, context: Context, @@ -51,6 +52,7 @@ internal class PartUploadTransferWorker( transferStatusUpdater.updateTransferState(transferRecord.mainUploadId, TransferState.IN_PROGRESS) multiPartUploadId = inputData.keyValueMap[MULTI_PART_UPLOAD_ID] as String partUploadProgressListener = PartUploadProgressListener(transferRecord, transferStatusUpdater) + val s3: S3Client = clientProvider.getStorageTransferClient(transferRecord.region, transferRecord.bucketName) return s3.withConfig { interceptors += UploadProgressListenerInterceptor(partUploadProgressListener) enableAccelerate = transferRecord.useAccelerateEndpoint == 1 diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/SinglePartUploadWorker.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/SinglePartUploadWorker.kt index 21de0db80c..515c36befc 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/SinglePartUploadWorker.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/SinglePartUploadWorker.kt @@ -22,13 +22,14 @@ import android.content.Context import androidx.work.WorkerParameters import aws.sdk.kotlin.services.s3.S3Client import aws.sdk.kotlin.services.s3.withConfig +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider import com.amplifyframework.storage.s3.transfer.TransferDB import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater import com.amplifyframework.storage.s3.transfer.UploadProgressListener import com.amplifyframework.storage.s3.transfer.UploadProgressListenerInterceptor internal class SinglePartUploadWorker( - private val s3: S3Client, + private val clientProvider: StorageTransferClientProvider, private val transferDB: TransferDB, private val transferStatusUpdater: TransferStatusUpdater, context: Context, @@ -40,6 +41,7 @@ internal class SinglePartUploadWorker( override suspend fun performWork(): Result { uploadProgressListener = UploadProgressListener(transferRecord, transferStatusUpdater) val putObjectRequest = createPutObjectRequest(transferRecord, uploadProgressListener) + val s3: S3Client = clientProvider.getStorageTransferClient(transferRecord.region, transferRecord.bucketName) return s3.withConfig { interceptors += UploadProgressListenerInterceptor(uploadProgressListener) enableAccelerate = transferRecord.useAccelerateEndpoint == 1 diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/TransferWorkerFactory.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/TransferWorkerFactory.kt index 88c0dc19f4..f491d1ee95 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/TransferWorkerFactory.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/TransferWorkerFactory.kt @@ -18,6 +18,7 @@ import android.content.Context import androidx.work.WorkerFactory import androidx.work.WorkerParameters import aws.sdk.kotlin.services.s3.S3Client +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider import com.amplifyframework.storage.s3.transfer.TransferDB import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater @@ -26,7 +27,7 @@ import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater **/ internal class TransferWorkerFactory( private val transferDB: TransferDB, - private val s3: S3Client, + private val clientProvider: StorageTransferClientProvider, private val transferStatusUpdater: TransferStatusUpdater ) : WorkerFactory() { override fun createWorker( @@ -37,7 +38,7 @@ internal class TransferWorkerFactory( when (workerClassName) { DownloadWorker::class.java.name -> return DownloadWorker( - s3, + clientProvider, transferDB, transferStatusUpdater, appContext, @@ -45,7 +46,7 @@ internal class TransferWorkerFactory( ) SinglePartUploadWorker::class.java.name -> return SinglePartUploadWorker( - s3, + clientProvider, transferDB, transferStatusUpdater, appContext, @@ -53,7 +54,7 @@ internal class TransferWorkerFactory( ) InitiateMultiPartUploadTransferWorker::class.java.name -> return InitiateMultiPartUploadTransferWorker( - s3, + clientProvider, transferDB, transferStatusUpdater, appContext, @@ -61,7 +62,7 @@ internal class TransferWorkerFactory( ) PartUploadTransferWorker::class.java.name -> return PartUploadTransferWorker( - s3, + clientProvider, transferDB, transferStatusUpdater, appContext, @@ -69,7 +70,7 @@ internal class TransferWorkerFactory( ) CompleteMultiPartUploadWorker::class.java.name -> return CompleteMultiPartUploadWorker( - s3, + clientProvider, transferDB, transferStatusUpdater, appContext, @@ -77,7 +78,7 @@ internal class TransferWorkerFactory( ) AbortMultiPartUploadWorker::class.java.name -> return AbortMultiPartUploadWorker( - s3, + clientProvider, transferDB, transferStatusUpdater, appContext, diff --git a/aws-storage-s3/src/test/java/com/amplifyframework/storage/AWSS3StorageServiceContainerTest.kt b/aws-storage-s3/src/test/java/com/amplifyframework/storage/AWSS3StorageServiceContainerTest.kt new file mode 100644 index 0000000000..3119dab168 --- /dev/null +++ b/aws-storage-s3/src/test/java/com/amplifyframework/storage/AWSS3StorageServiceContainerTest.kt @@ -0,0 +1,121 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amplifyframework.storage.s3 + +import android.content.Context +import com.amplifyframework.storage.BucketInfo +import com.amplifyframework.storage.ResolvedStorageBucket +import com.amplifyframework.storage.StorageBucket +import com.amplifyframework.storage.s3.service.AWSS3StorageService +import com.amplifyframework.storage.s3.service.AWSS3StorageServiceContainer +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider +import io.kotest.matchers.shouldBe +import io.kotest.matchers.shouldNotBe +import io.mockk.every +import io.mockk.mockk +import java.util.concurrent.ConcurrentHashMap +import org.junit.Before +import org.junit.Test + +class AWSS3StorageServiceContainerTest { + + private val storageServiceFactory = mockk { + every { create(any(), any(), any(), any()) } returns mockk() + } + private val context = mockk() + private val clientProvider = mockk() + private val bucketName = "testBucket" + private val region = "us-east-1" + + private lateinit var serviceContainerHashMap: ConcurrentHashMap + private lateinit var serviceContainer: AWSS3StorageServiceContainer + @Before + fun setUp() { + serviceContainerHashMap = ConcurrentHashMap() + serviceContainer = AWSS3StorageServiceContainer( + context, + storageServiceFactory, + clientProvider, + serviceContainerHashMap + ) + } + + @Test + fun `put default AWSS3Service in container`() { + val service = storageServiceFactory.create(context, region, bucketName, clientProvider) + serviceContainer.put(bucketName, service) + + serviceContainerHashMap.size shouldBe 1 + serviceContainerHashMap[bucketName] shouldNotBe null + } + + @Test + fun `get non-existent AWSS3Service in container with ResolvedStorageBucket creates new AWSService`() { + val bucketInfo = BucketInfo(bucketName, region) + val bucket: ResolvedStorageBucket = StorageBucket.fromBucketInfo(bucketInfo) as ResolvedStorageBucket + + val service = serviceContainer.get(bucket) + + service shouldNotBe null + serviceContainerHashMap.size shouldBe 1 + serviceContainerHashMap[bucketName] shouldNotBe null + serviceContainerHashMap[bucketName] shouldBe service + } + + @Test + fun `get WSS3Service in container multiple times with ResolvedStorageBucket creates only one service`() { + val bucketInfo = BucketInfo(bucketName, region) + val bucket: ResolvedStorageBucket = StorageBucket.fromBucketInfo(bucketInfo) as ResolvedStorageBucket + + val service = serviceContainer.get(bucket) + val service2 = serviceContainer.get(bucket) + + service shouldNotBe null + service2 shouldNotBe null + service shouldBe service2 + + serviceContainerHashMap.size shouldBe 1 + serviceContainerHashMap[bucketName] shouldNotBe null + serviceContainerHashMap[bucketName] shouldBe service + serviceContainerHashMap[bucketName] shouldBe service2 + } + + @Test + fun `get non-existent AWSS3Service in container with bucket name and region creates new AWSService`() { + val service = serviceContainer.get(bucketName, region) + + service shouldNotBe null + serviceContainerHashMap.size shouldBe 1 + serviceContainerHashMap[bucketName] shouldNotBe null + serviceContainerHashMap[bucketName] shouldBe service + } + + @Test + fun `get WSS3Service in container multiple times with bucket name and region creates only one service`() { + + val service = serviceContainer.get(bucketName, region) + val service2 = serviceContainer.get(bucketName, region) + + service shouldNotBe null + service2 shouldNotBe null + service shouldBe service2 + + serviceContainerHashMap.size shouldBe 1 + serviceContainerHashMap[bucketName] shouldNotBe null + serviceContainerHashMap[bucketName] shouldBe service + serviceContainerHashMap[bucketName] shouldBe service2 + } +} diff --git a/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/AWSS3StoragePluginTest.kt b/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/AWSS3StoragePluginTest.kt index a13158f1bf..5380af91b6 100644 --- a/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/AWSS3StoragePluginTest.kt +++ b/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/AWSS3StoragePluginTest.kt @@ -20,7 +20,6 @@ import com.amplifyframework.storage.InvalidStorageBucketException import com.amplifyframework.storage.StorageBucket import com.amplifyframework.storage.StorageException import com.amplifyframework.storage.s3.service.AWSS3StorageService -import com.amplifyframework.storage.s3.service.StorageService import com.amplifyframework.testutils.configuration.amplifyOutputsData import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.shouldNotBe @@ -32,8 +31,8 @@ import org.junit.Test class AWSS3StoragePluginTest { - private val storageServiceFactory = mockk { - every { create(any(), any(), any()) } returns mockk() + private val storageServiceFactory = mockk { + every { create(any(), any(), any(), any()) } returns mockk() } private val plugin = AWSS3StoragePlugin( @@ -54,7 +53,7 @@ class AWSS3StoragePluginTest { plugin.configure(data, mockk()) verify { - storageServiceFactory.create(any(), "test-region", "test-bucket") + storageServiceFactory.create(any(), "test-region", "test-bucket", any()) } } diff --git a/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/StorageComponentTest.java b/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/StorageComponentTest.java index c1ac9ca541..3e238acced 100644 --- a/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/StorageComponentTest.java +++ b/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/StorageComponentTest.java @@ -32,6 +32,7 @@ import com.amplifyframework.storage.s3.configuration.AWSS3StoragePluginConfiguration; import com.amplifyframework.storage.s3.service.AWSS3StorageService; import com.amplifyframework.storage.s3.service.StorageService; +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider; import com.amplifyframework.storage.s3.transfer.TransferListener; import com.amplifyframework.storage.s3.transfer.TransferObserver; import com.amplifyframework.testutils.Await; @@ -76,6 +77,7 @@ public final class StorageComponentTest { private StorageCategory storage; private StorageService storageService; + private StorageTransferClientProvider clientProvider; /** * Sets up Storage category by registering a mock AWSS3StoragePlugin @@ -88,7 +90,8 @@ public final class StorageComponentTest { public void setup() throws AmplifyException { this.storage = new StorageCategory(); this.storageService = mock(AWSS3StorageService.class); - StorageService.Factory storageServiceFactory = (context, region, bucket) -> storageService; + AWSS3StorageService.Factory storageServiceFactory + = (context, region, bucket, clientProvider) -> (AWSS3StorageService) storageService; AuthCredentialsProvider cognitoAuthProvider = mock(AuthCredentialsProvider.class); doReturn(RandomString.string()).when(cognitoAuthProvider).getIdentityId(null); this.storage.addPlugin(new AWSS3StoragePlugin(storageServiceFactory, diff --git a/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/AbortMultiPartUploadWorkerTest.kt b/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/AbortMultiPartUploadWorkerTest.kt index c3097d71c0..5772881318 100644 --- a/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/AbortMultiPartUploadWorkerTest.kt +++ b/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/AbortMultiPartUploadWorkerTest.kt @@ -24,6 +24,8 @@ import aws.sdk.kotlin.services.s3.model.AbortMultipartUploadRequest import aws.sdk.kotlin.services.s3.model.AbortMultipartUploadResponse import aws.sdk.kotlin.services.s3.withConfig import com.amplifyframework.storage.TransferState +import com.amplifyframework.storage.s3.transfer.S3StorageTransferClientProvider +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider import com.amplifyframework.storage.s3.transfer.TransferDB import com.amplifyframework.storage.s3.transfer.TransferRecord import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater @@ -52,6 +54,7 @@ internal class AbortMultiPartUploadWorkerTest { private lateinit var transferDB: TransferDB private lateinit var transferStatusUpdater: TransferStatusUpdater private lateinit var workerParameters: WorkerParameters + private lateinit var clientProvider: StorageTransferClientProvider @Before fun setup() { @@ -59,6 +62,7 @@ internal class AbortMultiPartUploadWorkerTest { context = ApplicationProvider.getApplicationContext() workerParameters = mockk(WorkerParameters::class.java.name) s3Client = spyk(recordPrivateCalls = true) + clientProvider = mockk(S3StorageTransferClientProvider::class.java.name) mockkStatic(S3Client::withConfig) transferDB = mockk(TransferDB::class.java.name) transferStatusUpdater = mockk(TransferStatusUpdater::class.java.name) @@ -66,6 +70,7 @@ internal class AbortMultiPartUploadWorkerTest { every { workerParameters.runAttemptCount }.answers { 1 } every { workerParameters.taskExecutor }.answers { ImmediateTaskExecutor() } every { any().withConfig(any()) }.answers { s3Client } + every { clientProvider.getStorageTransferClient(any(), any()) }.answers { s3Client } } @After @@ -95,13 +100,20 @@ internal class AbortMultiPartUploadWorkerTest { every { transferDB.getTransferRecordById(any()) }.answers { transferRecord } every { transferStatusUpdater.updateTransferState(any(), any()) }.answers { } - val worker = AbortMultiPartUploadWorker(s3Client, transferDB, transferStatusUpdater, context, workerParameters) + val worker = AbortMultiPartUploadWorker( + clientProvider, + transferDB, + transferStatusUpdater, + context, + workerParameters + ) val result = worker.doWork() val expectedResult = ListenableWorker.Result.success(workDataOf(BaseTransferWorker.OUTPUT_TRANSFER_RECORD_ID to 1)) verify(exactly = 1) { transferStatusUpdater.updateTransferState(1, TransferState.FAILED) } verify(exactly = 1) { any().withConfig(any()) } + verify(exactly = 1) { clientProvider.getStorageTransferClient(any(), any()) } assertEquals(expectedResult, result) } @@ -128,7 +140,13 @@ internal class AbortMultiPartUploadWorkerTest { every { transferDB.getTransferRecordById(any()) }.answers { transferRecord } every { transferStatusUpdater.updateTransferState(any(), any()) }.answers { } - val worker = AbortMultiPartUploadWorker(s3Client, transferDB, transferStatusUpdater, context, workerParameters) + val worker = AbortMultiPartUploadWorker( + clientProvider, + transferDB, + transferStatusUpdater, + context, + workerParameters + ) val result = worker.doWork() val expectedResult = @@ -157,7 +175,13 @@ internal class AbortMultiPartUploadWorkerTest { every { transferStatusUpdater.updateTransferState(any(), any()) }.answers { } every { transferStatusUpdater.updateOnError(any(), any()) }.answers { } - val worker = AbortMultiPartUploadWorker(s3Client, transferDB, transferStatusUpdater, context, workerParameters) + val worker = AbortMultiPartUploadWorker( + clientProvider, + transferDB, + transferStatusUpdater, + context, + workerParameters + ) val result = worker.doWork() val expectedResult = diff --git a/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/DownloadWorkerTest.kt b/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/DownloadWorkerTest.kt index 9d7ad8114c..446015b0cc 100644 --- a/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/DownloadWorkerTest.kt +++ b/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/DownloadWorkerTest.kt @@ -26,6 +26,8 @@ import aws.smithy.kotlin.runtime.content.ByteStream import aws.smithy.kotlin.runtime.content.fromFile import com.amplifyframework.storage.TransferState import com.amplifyframework.storage.s3.transfer.DownloadProgressListenerInterceptor +import com.amplifyframework.storage.s3.transfer.S3StorageTransferClientProvider +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider import com.amplifyframework.storage.s3.transfer.TransferDB import com.amplifyframework.storage.s3.transfer.TransferRecord import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater @@ -56,12 +58,14 @@ internal class DownloadWorkerTest { private lateinit var transferStatusUpdater: TransferStatusUpdater private lateinit var workerParameters: WorkerParameters private lateinit var downloadInterceptor: DownloadProgressListenerInterceptor + private lateinit var clientProvider: StorageTransferClientProvider @Before fun setup() { context = ApplicationProvider.getApplicationContext() workerParameters = mockk(WorkerParameters::class.java.name) s3Client = mockk(relaxed = true, relaxUnitFun = true) + clientProvider = mockk(S3StorageTransferClientProvider::class.java.name) mockkStatic(S3Client::withConfig) downloadInterceptor = mockk(relaxed = true, relaxUnitFun = true) transferDB = mockk(TransferDB::class.java.name) @@ -70,6 +74,7 @@ internal class DownloadWorkerTest { every { workerParameters.runAttemptCount }.answers { 1 } every { workerParameters.taskExecutor }.answers { ImmediateTaskExecutor() } every { s3Client.withConfig(any()) } returns s3Client + every { clientProvider.getStorageTransferClient(any(), any())}.answers { s3Client } } @After @@ -102,10 +107,11 @@ internal class DownloadWorkerTest { every { transferStatusUpdater.updateProgress(1, any(), any(), true, false) }.answers { } every { transferStatusUpdater.updateProgress(1, any(), any(), true, true) }.answers { } - val worker = DownloadWorker(s3Client, transferDB, transferStatusUpdater, context, workerParameters) + val worker = DownloadWorker(clientProvider, transferDB, transferStatusUpdater, context, workerParameters) val result = worker.doWork() verify(atLeast = 1) { transferStatusUpdater.updateProgress(1, 10 * 1024 * 1024, 10 * 1024 * 1024, true, true) } + verify(exactly = 1) { clientProvider.getStorageTransferClient(any(), any()) } val expectedResult = ListenableWorker.Result.success(workDataOf(BaseTransferWorker.OUTPUT_TRANSFER_RECORD_ID to 1)) assertEquals(expectedResult, result) @@ -131,7 +137,7 @@ internal class DownloadWorkerTest { every { transferStatusUpdater.updateTransferState(1, TransferState.FAILED) }.answers { } every { transferStatusUpdater.updateOnError(1, any()) }.answers { } - val worker = DownloadWorker(s3Client, transferDB, transferStatusUpdater, context, workerParameters) + val worker = DownloadWorker(clientProvider, transferDB, transferStatusUpdater, context, workerParameters) val result = worker.doWork() verify(exactly = 0) { diff --git a/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/InitiateMultiPartUploadTransferWorkerTest.kt b/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/InitiateMultiPartUploadTransferWorkerTest.kt index ea423d392e..852f660a95 100644 --- a/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/InitiateMultiPartUploadTransferWorkerTest.kt +++ b/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/InitiateMultiPartUploadTransferWorkerTest.kt @@ -23,6 +23,8 @@ import aws.sdk.kotlin.services.s3.S3Client import aws.sdk.kotlin.services.s3.model.CreateMultipartUploadResponse import aws.sdk.kotlin.services.s3.withConfig import com.amplifyframework.storage.TransferState +import com.amplifyframework.storage.s3.transfer.S3StorageTransferClientProvider +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider import com.amplifyframework.storage.s3.transfer.TransferDB import com.amplifyframework.storage.s3.transfer.TransferRecord import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater @@ -51,12 +53,14 @@ internal class InitiateMultiPartUploadTransferWorkerTest { private lateinit var transferDB: TransferDB private lateinit var transferStatusUpdater: TransferStatusUpdater private lateinit var workerParameters: WorkerParameters + private lateinit var clientProvider: StorageTransferClientProvider @Before fun setup() { context = ApplicationProvider.getApplicationContext() workerParameters = mockk(WorkerParameters::class.java.name) s3Client = mockk(relaxed = true) + clientProvider = mockk(S3StorageTransferClientProvider::class.java.name) mockkStatic(S3Client::withConfig) transferDB = mockk(TransferDB::class.java.name) transferStatusUpdater = mockk(TransferStatusUpdater::class.java.name) @@ -64,6 +68,7 @@ internal class InitiateMultiPartUploadTransferWorkerTest { every { workerParameters.runAttemptCount }.answers { 1 } every { workerParameters.taskExecutor }.answers { ImmediateTaskExecutor() } every { s3Client.withConfig(any()) } returns s3Client + every { clientProvider.getStorageTransferClient(any(), any())}.answers { s3Client } } @After @@ -89,7 +94,7 @@ internal class InitiateMultiPartUploadTransferWorkerTest { every { transferStatusUpdater.updateMultipartId(1, "upload_id") }.answers { } every { transferStatusUpdater.updateTransferState(any(), TransferState.IN_PROGRESS) }.answers { } val worker = InitiateMultiPartUploadTransferWorker( - s3Client, + clientProvider, transferDB, transferStatusUpdater, context, @@ -97,6 +102,7 @@ internal class InitiateMultiPartUploadTransferWorkerTest { ) val result = worker.doWork() verify(exactly = 1) { transferStatusUpdater.updateMultipartId(1, "upload_id") } + verify(exactly = 1) { clientProvider.getStorageTransferClient(any(), any()) } val output = workDataOf( BaseTransferWorker.MULTI_PART_UPLOAD_ID to "upload_id", BaseTransferWorker.TRANSFER_RECORD_ID to 1 @@ -124,7 +130,7 @@ internal class InitiateMultiPartUploadTransferWorkerTest { every { transferStatusUpdater.updateTransferState(any(), any()) }.answers { } val worker = InitiateMultiPartUploadTransferWorker( - s3Client, + clientProvider, transferDB, transferStatusUpdater, context,