diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/AWSS3StoragePlugin.java b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/AWSS3StoragePlugin.java index 01fde9ba41..090e8f7d85 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/AWSS3StoragePlugin.java +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/AWSS3StoragePlugin.java @@ -90,7 +90,8 @@ import com.amplifyframework.storage.s3.request.AWSS3StorageRemoveRequest; import com.amplifyframework.storage.s3.request.AWSS3StorageUploadRequest; import com.amplifyframework.storage.s3.service.AWSS3StorageService; -import com.amplifyframework.storage.s3.service.StorageService; +import com.amplifyframework.storage.s3.transfer.S3StorageTransferClientProvider; +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider; import com.amplifyframework.storage.s3.transfer.TransferObserver; import com.amplifyframework.storage.s3.transfer.TransferRecord; import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater; @@ -126,20 +127,33 @@ public final class AWSS3StoragePlugin extends StoragePlugin { private static final int DEFAULT_URL_EXPIRATION_DAYS = 7; - private final StorageService.Factory storageServiceFactory; + private final AWSS3StorageService.Factory storageServiceFactory; private final ExecutorService executorService; - private final AuthCredentialsProvider authCredentialsProvider; + private AuthCredentialsProvider authCredentialsProvider; private final AWSS3StoragePluginConfiguration awsS3StoragePluginConfiguration; private AWSS3StorageService defaultStorageService; @SuppressWarnings("deprecation") private StorageAccessLevel defaultAccessLevel; private int defaultUrlExpiration; - private Map awsS3StorageServicesByBucketName = new HashMap<>(); + private final Map awsS3StorageServicesByBucketName = new HashMap<>(); private Context context; @SuppressLint("UnsafeOptInUsageError") private List configuredBuckets; + @SuppressLint("UnsafeOptInUsageError") + private StorageTransferClientProvider clientProvider = new S3StorageTransferClientProvider((region, bucketName) -> { + if (region != null && bucketName != null) { + StorageBucket bucket = StorageBucket.fromBucketInfo(new BucketInfo(region, bucketName)); + return getAWSS3StorageService((ResolvedStorageBucket) bucket).getClient(); + } + + if (region != null) { + return AWSS3StorageService.getS3Client(region, authCredentialsProvider); + } + return defaultStorageService.getClient(); + }); + /** * Constructs the AWS S3 Storage Plugin initializing the executor service. */ @@ -162,13 +176,14 @@ public AWSS3StoragePlugin(AWSS3StoragePluginConfiguration awsS3StoragePluginConf @VisibleForTesting AWSS3StoragePlugin(AuthCredentialsProvider authCredentialsProvider) { - this((context, region, bucket) -> + this((context, region, bucket, clientProvider) -> new AWSS3StorageService( context, region, bucket, authCredentialsProvider, - AWS_S3_STORAGE_PLUGIN_KEY + AWS_S3_STORAGE_PLUGIN_KEY, + clientProvider ), authCredentialsProvider, new AWSS3StoragePluginConfiguration.Builder().build()); @@ -177,13 +192,14 @@ public AWSS3StoragePlugin(AWSS3StoragePluginConfiguration awsS3StoragePluginConf @VisibleForTesting AWSS3StoragePlugin(AuthCredentialsProvider authCredentialsProvider, AWSS3StoragePluginConfiguration awss3StoragePluginConfiguration) { - this((context, region, bucket) -> + this((context, region, bucket, clientProvider) -> new AWSS3StorageService( context, region, bucket, authCredentialsProvider, - AWS_S3_STORAGE_PLUGIN_KEY + AWS_S3_STORAGE_PLUGIN_KEY, + clientProvider ), authCredentialsProvider, awss3StoragePluginConfiguration); @@ -191,7 +207,7 @@ public AWSS3StoragePlugin(AWSS3StoragePluginConfiguration awsS3StoragePluginConf @VisibleForTesting AWSS3StoragePlugin( - StorageService.Factory storageServiceFactory, + AWSS3StorageService.Factory storageServiceFactory, AuthCredentialsProvider authCredentialsProvider, AWSS3StoragePluginConfiguration awss3StoragePluginConfiguration ) { @@ -282,10 +298,11 @@ private void configure( ) throws StorageException { try { this.context = context; - this.defaultStorageService = (AWSS3StorageService) storageServiceFactory.create( + this.defaultStorageService = storageServiceFactory.create( context, region, - bucket.getBucketInfo().getName()); + bucket.getBucketInfo().getName(), + clientProvider); this.awsS3StorageServicesByBucketName.clear(); this.awsS3StorageServicesByBucketName.put(bucket.getBucketInfo().getName(), this.defaultStorageService); } catch (RuntimeException exception) { @@ -935,7 +952,8 @@ public StorageRemoveOperation remove( return operation; } - + + @SuppressLint("UnsafeOptInUsageError") @Override @SuppressWarnings("deprecation") public void getTransfer( @@ -951,18 +969,23 @@ public void getTransfer( transferRecord.getId(), defaultStorageService.getTransferManager().getTransferStatusUpdater(), transferRecord.getBucketName(), + transferRecord.getRegion(), transferRecord.getKey(), transferRecord.getFile(), null, transferRecord.getState() != null ? transferRecord.getState() : TransferState.UNKNOWN); TransferType transferType = transferRecord.getType(); + + AWSS3StorageService storageService + = getAwss3StorageServiceFromTransferRecord(onError, transferRecord); + switch (Objects.requireNonNull(transferType)) { case UPLOAD: if (transferRecord.getFile().startsWith(TransferStatusUpdater.TEMP_FILE_PREFIX)) { AWSS3StorageUploadInputStreamOperation operation = new AWSS3StorageUploadInputStreamOperation( transferId, - defaultStorageService, + storageService, executorService, authCredentialsProvider, awsS3StoragePluginConfiguration, @@ -973,7 +996,7 @@ public void getTransfer( AWSS3StorageUploadFileOperation operation = new AWSS3StorageUploadFileOperation( transferId, - defaultStorageService, + storageService, executorService, authCredentialsProvider, awsS3StoragePluginConfiguration, @@ -987,7 +1010,7 @@ public void getTransfer( downloadFileOperation = new AWSS3StorageDownloadFileOperation( transferId, new File(transferRecord.getFile()), - defaultStorageService, + storageService, executorService, authCredentialsProvider, awsS3StoragePluginConfiguration, @@ -1009,6 +1032,25 @@ public void getTransfer( }); } + private AWSS3StorageService getAwss3StorageServiceFromTransferRecord( + @NonNull Consumer onError, + TransferRecord transferRecord + ) { + AWSS3StorageService storageService = defaultStorageService; + if (transferRecord.getRegion() != null && transferRecord.getBucketName() != null) { + try { + BucketInfo bucketInfo = new BucketInfo( + transferRecord.getBucketName(), + transferRecord.getRegion()); + StorageBucket bucket = StorageBucket.fromBucketInfo(bucketInfo); + storageService = getStorageService(bucket); + } catch (StorageException exception) { + onError.accept(exception); + } + } + return storageService; + } + @NonNull @SuppressWarnings("deprecation") @Override @@ -1133,7 +1175,7 @@ private AWSS3StorageService getAWSS3StorageService(OutputsStorageBucket outputsS AWSS3StorageService service = awsS3StorageServicesByBucketName.get(bucketName); if (service == null) { String region = configuredBucket.getAwsRegion(); - service = (AWSS3StorageService) storageServiceFactory.create(context, region, bucketName); + service = storageServiceFactory.create(context, region, bucketName, clientProvider); awsS3StorageServicesByBucketName.put(bucketName, service); } @@ -1150,7 +1192,7 @@ private AWSS3StorageService getAWSS3StorageService(ResolvedStorageBucket resolve AWSS3StorageService service = awsS3StorageServicesByBucketName.get(bucketName); if (service == null) { String region = resolvedStorageBucket.getBucketInfo().getRegion(); - service = (AWSS3StorageService) storageServiceFactory.create(context, region, bucketName); + service = storageServiceFactory.create(context, region, bucketName, clientProvider); awsS3StorageServicesByBucketName.put(bucketName, service); } return service; diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/TransferOperations.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/TransferOperations.kt index 47cd602c41..99d15ea9d8 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/TransferOperations.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/TransferOperations.kt @@ -70,6 +70,7 @@ internal object TransferOperations { transferRecord.id, transferStatusUpdater, transferRecord.bucketName, + transferRecord.region, transferRecord.key, transferRecord.file, listener diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/service/AWSS3StorageService.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/service/AWSS3StorageService.kt index 7d9eef745d..8d35bba5e8 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/service/AWSS3StorageService.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/service/AWSS3StorageService.kt @@ -32,6 +32,7 @@ import com.amplifyframework.storage.StorageItem import com.amplifyframework.storage.options.SubpathStrategy import com.amplifyframework.storage.options.SubpathStrategy.Exclude import com.amplifyframework.storage.result.StorageListResult +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider import com.amplifyframework.storage.s3.transfer.TransferManager import com.amplifyframework.storage.s3.transfer.TransferObserver import com.amplifyframework.storage.s3.transfer.TransferRecord @@ -46,7 +47,6 @@ import java.util.Date import kotlin.time.Duration.Companion.seconds import kotlin.time.ExperimentalTime import kotlinx.coroutines.runBlocking - /** * A representation of an S3 backend service endpoint. */ @@ -55,16 +55,27 @@ internal class AWSS3StorageService( private val awsRegion: String, private val s3BucketName: String, private val authCredentialsProvider: AuthCredentialsProvider, - private val awsS3StoragePluginKey: String + private val awsS3StoragePluginKey: String, + private val clientProvider: StorageTransferClientProvider ) : StorageService { + companion object { + @JvmStatic + fun getS3Client(region: String, authCredentialsProvider: AuthCredentialsProvider): S3Client { + return S3Client { + this.region = region + this.credentialsProvider = authCredentialsProvider + } + } + } + private var s3Client: S3Client = S3Client { region = awsRegion credentialsProvider = authCredentialsProvider } val transferManager: TransferManager = - TransferManager(context, s3Client, awsS3StoragePluginKey) + TransferManager(context, clientProvider, awsS3StoragePluginKey) /** * Generate pre-signed URL for an object. @@ -130,6 +141,7 @@ internal class AWSS3StorageService( return transferManager.download( transferId, s3BucketName, + awsRegion, serviceKey, file, useAccelerateEndpoint = useAccelerateEndpoint @@ -153,6 +165,7 @@ internal class AWSS3StorageService( return transferManager.upload( transferId, s3BucketName, + awsRegion, serviceKey, file, metadata, @@ -175,7 +188,7 @@ internal class AWSS3StorageService( metadata: ObjectMetadata, useAccelerateEndpoint: Boolean ): TransferObserver { - val uploadOptions = UploadOptions(s3BucketName, metadata) + val uploadOptions = UploadOptions(s3BucketName, awsRegion, metadata) return transferManager.upload(transferId, serviceKey, inputStream, uploadOptions, useAccelerateEndpoint) } @@ -420,4 +433,21 @@ internal class AWSS3StorageService( fun getClient(): S3Client { return s3Client } + + interface Factory { + /** + * Factory interface to instantiate [StorageService] object. + * + * @param context Android context + * @param region S3 bucket region + * @param bucketName Name of the bucket where the items are stored + * @return An instantiated storage service instance + */ + fun create( + context: Context, + region: String, + bucketName: String, + clientProvider: StorageTransferClientProvider + ): AWSS3StorageService + } } diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/S3StorageTransferClientProvider.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/S3StorageTransferClientProvider.kt new file mode 100644 index 0000000000..7f0e1ebe52 --- /dev/null +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/S3StorageTransferClientProvider.kt @@ -0,0 +1,27 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +package com.amplifyframework.storage.s3.transfer + +import aws.sdk.kotlin.services.s3.S3Client +import com.amplifyframework.auth.AuthCredentialsProvider +import com.amplifyframework.storage.StorageException + +internal class S3StorageTransferClientProvider( + private val createS3Client: (region: String?, bucketName: String?) -> S3Client +) : StorageTransferClientProvider { + override fun getStorageTransferClient(region: String?, bucketName: String?): S3Client { + return createS3Client(region, bucketName) + } +} diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/StorageTransferClientProvider.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/StorageTransferClientProvider.kt new file mode 100644 index 0000000000..5b430e3fc5 --- /dev/null +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/StorageTransferClientProvider.kt @@ -0,0 +1,23 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +package com.amplifyframework.storage.s3.transfer + +import aws.sdk.kotlin.services.s3.S3Client +import com.amplifyframework.annotations.InternalApiWarning + +@InternalApiWarning +internal interface StorageTransferClientProvider { + fun getStorageTransferClient(region: String?, bucketName: String?): S3Client +} diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferDB.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferDB.kt index 6413cf7c45..91195026a2 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferDB.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferDB.kt @@ -79,6 +79,7 @@ internal class TransferDB private constructor(context: Context) { fun insertMultipartUploadRecord( transferId: String, bucket: String, + region: String, key: String, file: File, fileOffset: Long, @@ -90,7 +91,7 @@ internal class TransferDB private constructor(context: Context) { ): Uri { val values: ContentValues = generateContentValuesForMultiPartUpload( transferId, - bucket, key, file, + bucket, region, key, file, fileOffset, partNumber, uploadId, bytesTotal, isLastPart, ObjectMetadata(), null, useAccelerateEndpoint @@ -114,6 +115,7 @@ internal class TransferDB private constructor(context: Context) { transferId: String, type: TransferType, bucket: String, + region: String, key: String, file: File?, cannedAcl: ObjectCannedAcl? = null, @@ -124,6 +126,7 @@ internal class TransferDB private constructor(context: Context) { transferId, type, bucket, + region, key, file, metadata, @@ -398,7 +401,7 @@ internal class TransferDB private constructor(context: Context) { */ fun getTransferRecordById(id: Int): TransferRecord? { var transferRecord: TransferRecord? = null - var c: Cursor? = null + var c: Cursor? try { c = queryTransferById(id) c?.use { @@ -415,7 +418,7 @@ internal class TransferDB private constructor(context: Context) { fun getTransferByTransferId(transferId: String): TransferRecord? { var transferRecord: TransferRecord? = null - var c: Cursor? = null + var c: Cursor? try { c = transferDBHelper.query(getTransferRecordIdUri(transferId)) c.use { @@ -560,6 +563,7 @@ internal class TransferDB private constructor(context: Context) { transferId: String, type: TransferType, bucket: String, + region: String, key: String, file: File, metadata: ObjectMetadata?, @@ -570,6 +574,7 @@ internal class TransferDB private constructor(context: Context) { transferId, type, bucket, + region, key, file, metadata, @@ -602,6 +607,7 @@ internal class TransferDB private constructor(context: Context) { fun generateContentValuesForMultiPartUpload( transferId: String, bucket: String?, + region: String?, key: String?, file: File, fileOffset: Long, @@ -611,13 +617,14 @@ internal class TransferDB private constructor(context: Context) { isLastPart: Int, metadata: ObjectMetadata?, cannedAcl: ObjectCannedAcl?, - useAccelerateEndpoint: Boolean + useAccelerateEndpoint: Boolean, ): ContentValues { val values = ContentValues() values.put(TransferTable.COLUMN_TRANSFER_ID, transferId) values.put(TransferTable.COLUMN_TYPE, TransferType.UPLOAD.toString()) values.put(TransferTable.COLUMN_STATE, TransferState.WAITING.toString()) values.put(TransferTable.COLUMN_BUCKET_NAME, bucket) + values.put(TransferTable.COLUMN_REGION, region) values.put(TransferTable.COLUMN_KEY, key) values.put(TransferTable.COLUMN_FILE, file.absolutePath) values.put(TransferTable.COLUMN_BYTES_CURRENT, 0L) @@ -723,6 +730,7 @@ internal class TransferDB private constructor(context: Context) { transferId: String, type: TransferType, bucket: String, + region: String, key: String, file: File?, metadata: ObjectMetadata?, @@ -734,6 +742,7 @@ internal class TransferDB private constructor(context: Context) { values.put(TransferTable.COLUMN_TYPE, type.toString()) values.put(TransferTable.COLUMN_STATE, TransferState.WAITING.toString()) values.put(TransferTable.COLUMN_BUCKET_NAME, bucket) + values.put(TransferTable.COLUMN_REGION, region) values.put(TransferTable.COLUMN_KEY, key) values.put(TransferTable.COLUMN_FILE, file?.absolutePath) values.put(TransferTable.COLUMN_BYTES_CURRENT, 0L) diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferDBHelper.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferDBHelper.kt index 211ea745e5..4403d51df4 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferDBHelper.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferDBHelper.kt @@ -48,7 +48,7 @@ internal class TransferDBHelper(private val context: Context) : SQLiteOpenHelper // This represents the latest database version. // Update this when the database is being upgraded. - private const val DATABASE_VERSION = 9 + private const val DATABASE_VERSION = 10 private const val BASE_PATH = "transfers" private const val TRANSFERS = 10 private const val TRANSFER_ID = 20 diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferManager.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferManager.kt index 8a4bd9cb30..119aeeafb9 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferManager.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferManager.kt @@ -45,7 +45,7 @@ import kotlin.math.min */ internal class TransferManager( context: Context, - s3: S3Client, + clientProvider: StorageTransferClientProvider, private val pluginKey: String, private val workManager: WorkManager = WorkManager.getInstance(context) ) { @@ -71,7 +71,7 @@ internal class TransferManager( init { RouterWorker.workerFactories[pluginKey] = TransferWorkerFactory( transferDB, - s3, + clientProvider, transferStatusUpdater ) } @@ -93,6 +93,7 @@ internal class TransferManager( fun upload( transferId: String, bucket: String, + region: String, key: String, file: File, metadata: ObjectMetadata, @@ -101,12 +102,13 @@ internal class TransferManager( useAccelerateEndpoint: Boolean = false ): TransferObserver { val transferRecordId = if (shouldUploadInMultipart(file)) { - createMultipartUploadRecords(transferId, bucket, key, file, metadata, cannedAcl, useAccelerateEndpoint) + createMultipartUploadRecords(transferId, bucket, region, key, file, metadata, cannedAcl, useAccelerateEndpoint) } else { val uri = transferDB.insertSingleTransferRecord( transferId, TransferType.UPLOAD, bucket, + region, key, file, cannedAcl, @@ -147,6 +149,7 @@ internal class TransferManager( return upload( transferId, options.bucket, + options.region, key, file, options.objectMetadata, @@ -160,6 +163,7 @@ internal class TransferManager( fun download( transferId: String, bucket: String, + region: String, key: String, file: File, listener: TransferListener? = null, @@ -172,6 +176,7 @@ internal class TransferManager( transferId, TransferType.DOWNLOAD, bucket, + region, key, file, useAccelerateEndpoint = useAccelerateEndpoint @@ -246,6 +251,7 @@ internal class TransferManager( private fun createMultipartUploadRecords( transferId: String, bucket: String, + region: String, key: String, file: File, metadata: ObjectMetadata, @@ -263,6 +269,7 @@ internal class TransferManager( contentValues[0] = transferDB.generateContentValuesForMultiPartUpload( transferId, bucket, + region, key, file, fileOffset, @@ -279,6 +286,7 @@ internal class TransferManager( contentValues[partNum] = transferDB.generateContentValuesForMultiPartUpload( UUID.randomUUID().toString(), bucket, + region, key, file, fileOffset, diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferObserver.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferObserver.kt index 021e06a5b8..b6147f5cfe 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferObserver.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferObserver.kt @@ -24,6 +24,7 @@ internal data class TransferObserver @JvmOverloads constructor( val id: Int, private val transferStatusUpdater: TransferStatusUpdater, val bucket: String? = null, + val region: String? = null, val key: String? = null, val filePath: String? = null, private val listener: TransferListener? = null, diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferRecord.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferRecord.kt index c32a85c7ce..73d219b97d 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferRecord.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferRecord.kt @@ -36,6 +36,7 @@ internal data class TransferRecord( var type: TransferType? = null, var state: TransferState? = null, var bucketName: String? = null, + var region: String? = null, var key: String? = null, var versionId: String? = null, var file: String = "", @@ -80,6 +81,8 @@ internal data class TransferRecord( ) this.bucketName = c.getString(c.getColumnIndexOrThrow(TransferTable.COLUMN_BUCKET_NAME)) + this.region = + c.getString(c.getColumnIndexOrThrow(TransferTable.COLUMN_REGION)) this.key = c.getString(c.getColumnIndexOrThrow(TransferTable.COLUMN_KEY)) this.versionId = c.getString(c.getColumnIndexOrThrow(TransferTable.COLUMN_VERSION_ID)) diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferTable.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferTable.kt index 85b515062d..c80a95320a 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferTable.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/TransferTable.kt @@ -134,6 +134,8 @@ internal class TransferTable { const val COLUMN_USE_ACCELERATE_ENDPOINT = "useAccelerateEndpoint" + const val COLUMN_REGION = "region" + private const val TABLE_VERSION_2 = 2 private const val TABLE_VERSION_3 = 3 private const val TABLE_VERSION_4 = 4 @@ -142,6 +144,7 @@ internal class TransferTable { private const val TABLE_VERSION_7 = 7 private const val TABLE_VERSION_8 = 8 private const val TABLE_VERSION_9 = 9 + private const val TABLE_VERSION_10 = 10 // Database creation SQL statement const val DATABASE_CREATE = "create table $TABLE_TRANSFER (" + @@ -150,6 +153,7 @@ internal class TransferTable { "$COLUMN_TYPE text not null, " + "$COLUMN_STATE text not null, " + "$COLUMN_BUCKET_NAME text not null, " + + "$COLUMN_REGION text, " + "$COLUMN_KEY text not null," + "$COLUMN_VERSION_ID text, " + "$COLUMN_BYTES_TOTAL bigint, " + @@ -219,6 +223,9 @@ internal class TransferTable { if (TABLE_VERSION_9 in (oldVersion + 1)..newVersion) { addVersion9Columns(database) } + if (TABLE_VERSION_10 in (oldVersion + 1)..newVersion) { + addVersion10Columns(database) + } database.setTransactionSuccessful() database.endTransaction() } @@ -296,5 +303,11 @@ internal class TransferTable { "DEFAULT 0;" database.execSQL(addConnectionType) } + + private fun addVersion10Columns(database: SQLiteDatabase) { + val addRegion = "ALTER TABLE $TABLE_TRANSFER ADD COLUMN $COLUMN_REGION text " + + "DEFAULT null;" + database.execSQL(addRegion) + } } } diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/UploadOptions.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/UploadOptions.kt index f68453c757..635bc7eb7f 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/UploadOptions.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/UploadOptions.kt @@ -24,6 +24,7 @@ import com.amplifyframework.storage.ObjectMetadata internal data class UploadOptions @JvmOverloads constructor( val bucket: String, + val region: String, val objectMetadata: ObjectMetadata = ObjectMetadata(), val cannedAcl: ObjectCannedAcl? = null, val transferListener: TransferListener? = null diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/AbortMultiPartUploadWorker.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/AbortMultiPartUploadWorker.kt index 30aa2d20b9..3f9cb5efb9 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/AbortMultiPartUploadWorker.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/AbortMultiPartUploadWorker.kt @@ -20,6 +20,7 @@ import aws.sdk.kotlin.services.s3.S3Client import aws.sdk.kotlin.services.s3.abortMultipartUpload import aws.sdk.kotlin.services.s3.withConfig import com.amplifyframework.storage.TransferState +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider import com.amplifyframework.storage.s3.transfer.TransferDB import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater @@ -27,7 +28,7 @@ import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater * Worker to abort pending multipart upload **/ internal class AbortMultiPartUploadWorker( - private val s3: S3Client, + private val clientProvider: StorageTransferClientProvider, private val transferDB: TransferDB, private val transferStatusUpdater: TransferStatusUpdater, context: Context, @@ -35,6 +36,7 @@ internal class AbortMultiPartUploadWorker( ) : BaseTransferWorker(transferStatusUpdater, transferDB, context, workerParameters) { override suspend fun performWork(): Result { + val s3: S3Client = clientProvider.getStorageTransferClient(transferRecord.region, transferRecord.bucketName) return s3.withConfig { enableAccelerate = transferRecord.useAccelerateEndpoint == 1 }.abortMultipartUpload { diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/CompleteMultiPartUploadWorker.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/CompleteMultiPartUploadWorker.kt index 8b48014a87..d7cadd5dd2 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/CompleteMultiPartUploadWorker.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/CompleteMultiPartUploadWorker.kt @@ -20,6 +20,7 @@ import aws.sdk.kotlin.services.s3.S3Client import aws.sdk.kotlin.services.s3.completeMultipartUpload import aws.sdk.kotlin.services.s3.model.CompletedMultipartUpload import aws.sdk.kotlin.services.s3.withConfig +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider import com.amplifyframework.storage.s3.transfer.TransferDB import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater @@ -27,7 +28,7 @@ import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater * Worker to complete multipart upload **/ internal class CompleteMultiPartUploadWorker( - private val s3: S3Client, + private val clientProvider: StorageTransferClientProvider, private val transferDB: TransferDB, private val transferStatusUpdater: TransferStatusUpdater, context: Context, @@ -36,6 +37,7 @@ internal class CompleteMultiPartUploadWorker( override suspend fun performWork(): Result { val completedParts = transferDB.queryPartETagsOfUpload(transferRecord.id) + val s3: S3Client = clientProvider.getStorageTransferClient(transferRecord.region, transferRecord.bucketName) return s3.withConfig { enableAccelerate = transferRecord.useAccelerateEndpoint == 1 }.completeMultipartUpload { diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/DownloadWorker.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/DownloadWorker.kt index 87a9db5612..9af6f8d70d 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/DownloadWorker.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/DownloadWorker.kt @@ -26,6 +26,7 @@ import aws.smithy.kotlin.runtime.io.SdkSource import aws.smithy.kotlin.runtime.io.buffer import com.amplifyframework.storage.s3.transfer.DownloadProgressListener import com.amplifyframework.storage.s3.transfer.DownloadProgressListenerInterceptor +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider import com.amplifyframework.storage.s3.transfer.TransferDB import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater import java.io.BufferedOutputStream @@ -40,7 +41,7 @@ import kotlinx.coroutines.withContext * Worker to perform download file task. */ internal class DownloadWorker( - private val s3: S3Client, + private val clientProvider: StorageTransferClientProvider, private val transferDB: TransferDB, private val transferStatusUpdater: TransferStatusUpdater, context: Context, @@ -50,6 +51,7 @@ internal class DownloadWorker( private lateinit var downloadProgressListener: DownloadProgressListener private val defaultBufferSize = 8192L override suspend fun performWork(): Result { + val s3: S3Client = clientProvider.getStorageTransferClient(transferRecord.region, transferRecord.bucketName) s3.withConfig { enableAccelerate = transferRecord.useAccelerateEndpoint == 1 } diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/InitiateMultiPartUploadTransferWorker.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/InitiateMultiPartUploadTransferWorker.kt index ec7be7cb35..b3c013b4bc 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/InitiateMultiPartUploadTransferWorker.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/InitiateMultiPartUploadTransferWorker.kt @@ -21,6 +21,7 @@ import aws.sdk.kotlin.services.s3.S3Client import aws.sdk.kotlin.services.s3.createMultipartUpload import aws.sdk.kotlin.services.s3.withConfig import com.amplifyframework.storage.TransferState +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider import com.amplifyframework.storage.s3.transfer.TransferDB import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater @@ -28,7 +29,7 @@ import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater * Worker to initiate multipart upload **/ internal class InitiateMultiPartUploadTransferWorker( - private val s3: S3Client, + private val clientProvider: StorageTransferClientProvider, private val transferDB: TransferDB, private val transferStatusUpdater: TransferStatusUpdater, context: Context, @@ -36,6 +37,7 @@ internal class InitiateMultiPartUploadTransferWorker( ) : BaseTransferWorker(transferStatusUpdater, transferDB, context, workerParameters) { override suspend fun performWork(): Result { + val s3: S3Client = clientProvider.getStorageTransferClient(transferRecord.region, transferRecord.bucketName) transferStatusUpdater.updateTransferState(transferRecord.id, TransferState.IN_PROGRESS) val putObjectRequest = createPutObjectRequest(transferRecord, null) return s3.withConfig { diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/PartUploadTransferWorker.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/PartUploadTransferWorker.kt index d145786b6d..b7a0f6760d 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/PartUploadTransferWorker.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/PartUploadTransferWorker.kt @@ -22,6 +22,7 @@ import aws.sdk.kotlin.services.s3.withConfig import aws.smithy.kotlin.runtime.content.asByteStream import com.amplifyframework.storage.TransferState import com.amplifyframework.storage.s3.transfer.PartUploadProgressListener +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider import com.amplifyframework.storage.s3.transfer.TransferDB import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater import com.amplifyframework.storage.s3.transfer.UploadProgressListenerInterceptor @@ -33,7 +34,7 @@ import kotlinx.coroutines.isActive * Worker to upload a part for multipart upload **/ internal class PartUploadTransferWorker( - private val s3: S3Client, + private val clientProvider: StorageTransferClientProvider, private val transferDB: TransferDB, private val transferStatusUpdater: TransferStatusUpdater, context: Context, @@ -51,6 +52,7 @@ internal class PartUploadTransferWorker( transferStatusUpdater.updateTransferState(transferRecord.mainUploadId, TransferState.IN_PROGRESS) multiPartUploadId = inputData.keyValueMap[MULTI_PART_UPLOAD_ID] as String partUploadProgressListener = PartUploadProgressListener(transferRecord, transferStatusUpdater) + val s3: S3Client = clientProvider.getStorageTransferClient(transferRecord.region, transferRecord.bucketName) return s3.withConfig { interceptors += UploadProgressListenerInterceptor(partUploadProgressListener) enableAccelerate = transferRecord.useAccelerateEndpoint == 1 diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/SinglePartUploadWorker.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/SinglePartUploadWorker.kt index 21de0db80c..515c36befc 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/SinglePartUploadWorker.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/SinglePartUploadWorker.kt @@ -22,13 +22,14 @@ import android.content.Context import androidx.work.WorkerParameters import aws.sdk.kotlin.services.s3.S3Client import aws.sdk.kotlin.services.s3.withConfig +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider import com.amplifyframework.storage.s3.transfer.TransferDB import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater import com.amplifyframework.storage.s3.transfer.UploadProgressListener import com.amplifyframework.storage.s3.transfer.UploadProgressListenerInterceptor internal class SinglePartUploadWorker( - private val s3: S3Client, + private val clientProvider: StorageTransferClientProvider, private val transferDB: TransferDB, private val transferStatusUpdater: TransferStatusUpdater, context: Context, @@ -40,6 +41,7 @@ internal class SinglePartUploadWorker( override suspend fun performWork(): Result { uploadProgressListener = UploadProgressListener(transferRecord, transferStatusUpdater) val putObjectRequest = createPutObjectRequest(transferRecord, uploadProgressListener) + val s3: S3Client = clientProvider.getStorageTransferClient(transferRecord.region, transferRecord.bucketName) return s3.withConfig { interceptors += UploadProgressListenerInterceptor(uploadProgressListener) enableAccelerate = transferRecord.useAccelerateEndpoint == 1 diff --git a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/TransferWorkerFactory.kt b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/TransferWorkerFactory.kt index 88c0dc19f4..f491d1ee95 100644 --- a/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/TransferWorkerFactory.kt +++ b/aws-storage-s3/src/main/java/com/amplifyframework/storage/s3/transfer/worker/TransferWorkerFactory.kt @@ -18,6 +18,7 @@ import android.content.Context import androidx.work.WorkerFactory import androidx.work.WorkerParameters import aws.sdk.kotlin.services.s3.S3Client +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider import com.amplifyframework.storage.s3.transfer.TransferDB import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater @@ -26,7 +27,7 @@ import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater **/ internal class TransferWorkerFactory( private val transferDB: TransferDB, - private val s3: S3Client, + private val clientProvider: StorageTransferClientProvider, private val transferStatusUpdater: TransferStatusUpdater ) : WorkerFactory() { override fun createWorker( @@ -37,7 +38,7 @@ internal class TransferWorkerFactory( when (workerClassName) { DownloadWorker::class.java.name -> return DownloadWorker( - s3, + clientProvider, transferDB, transferStatusUpdater, appContext, @@ -45,7 +46,7 @@ internal class TransferWorkerFactory( ) SinglePartUploadWorker::class.java.name -> return SinglePartUploadWorker( - s3, + clientProvider, transferDB, transferStatusUpdater, appContext, @@ -53,7 +54,7 @@ internal class TransferWorkerFactory( ) InitiateMultiPartUploadTransferWorker::class.java.name -> return InitiateMultiPartUploadTransferWorker( - s3, + clientProvider, transferDB, transferStatusUpdater, appContext, @@ -61,7 +62,7 @@ internal class TransferWorkerFactory( ) PartUploadTransferWorker::class.java.name -> return PartUploadTransferWorker( - s3, + clientProvider, transferDB, transferStatusUpdater, appContext, @@ -69,7 +70,7 @@ internal class TransferWorkerFactory( ) CompleteMultiPartUploadWorker::class.java.name -> return CompleteMultiPartUploadWorker( - s3, + clientProvider, transferDB, transferStatusUpdater, appContext, @@ -77,7 +78,7 @@ internal class TransferWorkerFactory( ) AbortMultiPartUploadWorker::class.java.name -> return AbortMultiPartUploadWorker( - s3, + clientProvider, transferDB, transferStatusUpdater, appContext, diff --git a/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/AWSS3StoragePluginTest.kt b/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/AWSS3StoragePluginTest.kt index 5d15f2d3af..da4d54b368 100644 --- a/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/AWSS3StoragePluginTest.kt +++ b/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/AWSS3StoragePluginTest.kt @@ -19,7 +19,6 @@ import com.amplifyframework.storage.BucketInfo import com.amplifyframework.storage.StorageBucket import com.amplifyframework.storage.StorageException import com.amplifyframework.storage.s3.service.AWSS3StorageService -import com.amplifyframework.storage.s3.service.StorageService import com.amplifyframework.testutils.configuration.amplifyOutputsData import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.shouldNotBe @@ -30,8 +29,8 @@ import org.junit.Test class AWSS3StoragePluginTest { - private val storageServiceFactory = mockk { - every { create(any(), any(), any()) } returns mockk() + private val storageServiceFactory = mockk { + every { create(any(), any(), any(), any()) } returns mockk() } private val plugin = AWSS3StoragePlugin( @@ -52,7 +51,7 @@ class AWSS3StoragePluginTest { plugin.configure(data, mockk()) verify { - storageServiceFactory.create(any(), "test-region", "test-bucket") + storageServiceFactory.create(any(), "test-region", "test-bucket", any()) } } diff --git a/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/StorageComponentTest.java b/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/StorageComponentTest.java index c1ac9ca541..3e238acced 100644 --- a/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/StorageComponentTest.java +++ b/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/StorageComponentTest.java @@ -32,6 +32,7 @@ import com.amplifyframework.storage.s3.configuration.AWSS3StoragePluginConfiguration; import com.amplifyframework.storage.s3.service.AWSS3StorageService; import com.amplifyframework.storage.s3.service.StorageService; +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider; import com.amplifyframework.storage.s3.transfer.TransferListener; import com.amplifyframework.storage.s3.transfer.TransferObserver; import com.amplifyframework.testutils.Await; @@ -76,6 +77,7 @@ public final class StorageComponentTest { private StorageCategory storage; private StorageService storageService; + private StorageTransferClientProvider clientProvider; /** * Sets up Storage category by registering a mock AWSS3StoragePlugin @@ -88,7 +90,8 @@ public final class StorageComponentTest { public void setup() throws AmplifyException { this.storage = new StorageCategory(); this.storageService = mock(AWSS3StorageService.class); - StorageService.Factory storageServiceFactory = (context, region, bucket) -> storageService; + AWSS3StorageService.Factory storageServiceFactory + = (context, region, bucket, clientProvider) -> (AWSS3StorageService) storageService; AuthCredentialsProvider cognitoAuthProvider = mock(AuthCredentialsProvider.class); doReturn(RandomString.string()).when(cognitoAuthProvider).getIdentityId(null); this.storage.addPlugin(new AWSS3StoragePlugin(storageServiceFactory, diff --git a/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/AbortMultiPartUploadWorkerTest.kt b/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/AbortMultiPartUploadWorkerTest.kt index c3097d71c0..5772881318 100644 --- a/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/AbortMultiPartUploadWorkerTest.kt +++ b/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/AbortMultiPartUploadWorkerTest.kt @@ -24,6 +24,8 @@ import aws.sdk.kotlin.services.s3.model.AbortMultipartUploadRequest import aws.sdk.kotlin.services.s3.model.AbortMultipartUploadResponse import aws.sdk.kotlin.services.s3.withConfig import com.amplifyframework.storage.TransferState +import com.amplifyframework.storage.s3.transfer.S3StorageTransferClientProvider +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider import com.amplifyframework.storage.s3.transfer.TransferDB import com.amplifyframework.storage.s3.transfer.TransferRecord import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater @@ -52,6 +54,7 @@ internal class AbortMultiPartUploadWorkerTest { private lateinit var transferDB: TransferDB private lateinit var transferStatusUpdater: TransferStatusUpdater private lateinit var workerParameters: WorkerParameters + private lateinit var clientProvider: StorageTransferClientProvider @Before fun setup() { @@ -59,6 +62,7 @@ internal class AbortMultiPartUploadWorkerTest { context = ApplicationProvider.getApplicationContext() workerParameters = mockk(WorkerParameters::class.java.name) s3Client = spyk(recordPrivateCalls = true) + clientProvider = mockk(S3StorageTransferClientProvider::class.java.name) mockkStatic(S3Client::withConfig) transferDB = mockk(TransferDB::class.java.name) transferStatusUpdater = mockk(TransferStatusUpdater::class.java.name) @@ -66,6 +70,7 @@ internal class AbortMultiPartUploadWorkerTest { every { workerParameters.runAttemptCount }.answers { 1 } every { workerParameters.taskExecutor }.answers { ImmediateTaskExecutor() } every { any().withConfig(any()) }.answers { s3Client } + every { clientProvider.getStorageTransferClient(any(), any()) }.answers { s3Client } } @After @@ -95,13 +100,20 @@ internal class AbortMultiPartUploadWorkerTest { every { transferDB.getTransferRecordById(any()) }.answers { transferRecord } every { transferStatusUpdater.updateTransferState(any(), any()) }.answers { } - val worker = AbortMultiPartUploadWorker(s3Client, transferDB, transferStatusUpdater, context, workerParameters) + val worker = AbortMultiPartUploadWorker( + clientProvider, + transferDB, + transferStatusUpdater, + context, + workerParameters + ) val result = worker.doWork() val expectedResult = ListenableWorker.Result.success(workDataOf(BaseTransferWorker.OUTPUT_TRANSFER_RECORD_ID to 1)) verify(exactly = 1) { transferStatusUpdater.updateTransferState(1, TransferState.FAILED) } verify(exactly = 1) { any().withConfig(any()) } + verify(exactly = 1) { clientProvider.getStorageTransferClient(any(), any()) } assertEquals(expectedResult, result) } @@ -128,7 +140,13 @@ internal class AbortMultiPartUploadWorkerTest { every { transferDB.getTransferRecordById(any()) }.answers { transferRecord } every { transferStatusUpdater.updateTransferState(any(), any()) }.answers { } - val worker = AbortMultiPartUploadWorker(s3Client, transferDB, transferStatusUpdater, context, workerParameters) + val worker = AbortMultiPartUploadWorker( + clientProvider, + transferDB, + transferStatusUpdater, + context, + workerParameters + ) val result = worker.doWork() val expectedResult = @@ -157,7 +175,13 @@ internal class AbortMultiPartUploadWorkerTest { every { transferStatusUpdater.updateTransferState(any(), any()) }.answers { } every { transferStatusUpdater.updateOnError(any(), any()) }.answers { } - val worker = AbortMultiPartUploadWorker(s3Client, transferDB, transferStatusUpdater, context, workerParameters) + val worker = AbortMultiPartUploadWorker( + clientProvider, + transferDB, + transferStatusUpdater, + context, + workerParameters + ) val result = worker.doWork() val expectedResult = diff --git a/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/DownloadWorkerTest.kt b/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/DownloadWorkerTest.kt index 9d7ad8114c..446015b0cc 100644 --- a/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/DownloadWorkerTest.kt +++ b/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/DownloadWorkerTest.kt @@ -26,6 +26,8 @@ import aws.smithy.kotlin.runtime.content.ByteStream import aws.smithy.kotlin.runtime.content.fromFile import com.amplifyframework.storage.TransferState import com.amplifyframework.storage.s3.transfer.DownloadProgressListenerInterceptor +import com.amplifyframework.storage.s3.transfer.S3StorageTransferClientProvider +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider import com.amplifyframework.storage.s3.transfer.TransferDB import com.amplifyframework.storage.s3.transfer.TransferRecord import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater @@ -56,12 +58,14 @@ internal class DownloadWorkerTest { private lateinit var transferStatusUpdater: TransferStatusUpdater private lateinit var workerParameters: WorkerParameters private lateinit var downloadInterceptor: DownloadProgressListenerInterceptor + private lateinit var clientProvider: StorageTransferClientProvider @Before fun setup() { context = ApplicationProvider.getApplicationContext() workerParameters = mockk(WorkerParameters::class.java.name) s3Client = mockk(relaxed = true, relaxUnitFun = true) + clientProvider = mockk(S3StorageTransferClientProvider::class.java.name) mockkStatic(S3Client::withConfig) downloadInterceptor = mockk(relaxed = true, relaxUnitFun = true) transferDB = mockk(TransferDB::class.java.name) @@ -70,6 +74,7 @@ internal class DownloadWorkerTest { every { workerParameters.runAttemptCount }.answers { 1 } every { workerParameters.taskExecutor }.answers { ImmediateTaskExecutor() } every { s3Client.withConfig(any()) } returns s3Client + every { clientProvider.getStorageTransferClient(any(), any())}.answers { s3Client } } @After @@ -102,10 +107,11 @@ internal class DownloadWorkerTest { every { transferStatusUpdater.updateProgress(1, any(), any(), true, false) }.answers { } every { transferStatusUpdater.updateProgress(1, any(), any(), true, true) }.answers { } - val worker = DownloadWorker(s3Client, transferDB, transferStatusUpdater, context, workerParameters) + val worker = DownloadWorker(clientProvider, transferDB, transferStatusUpdater, context, workerParameters) val result = worker.doWork() verify(atLeast = 1) { transferStatusUpdater.updateProgress(1, 10 * 1024 * 1024, 10 * 1024 * 1024, true, true) } + verify(exactly = 1) { clientProvider.getStorageTransferClient(any(), any()) } val expectedResult = ListenableWorker.Result.success(workDataOf(BaseTransferWorker.OUTPUT_TRANSFER_RECORD_ID to 1)) assertEquals(expectedResult, result) @@ -131,7 +137,7 @@ internal class DownloadWorkerTest { every { transferStatusUpdater.updateTransferState(1, TransferState.FAILED) }.answers { } every { transferStatusUpdater.updateOnError(1, any()) }.answers { } - val worker = DownloadWorker(s3Client, transferDB, transferStatusUpdater, context, workerParameters) + val worker = DownloadWorker(clientProvider, transferDB, transferStatusUpdater, context, workerParameters) val result = worker.doWork() verify(exactly = 0) { diff --git a/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/InitiateMultiPartUploadTransferWorkerTest.kt b/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/InitiateMultiPartUploadTransferWorkerTest.kt index ea423d392e..852f660a95 100644 --- a/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/InitiateMultiPartUploadTransferWorkerTest.kt +++ b/aws-storage-s3/src/test/java/com/amplifyframework/storage/s3/transfer/worker/InitiateMultiPartUploadTransferWorkerTest.kt @@ -23,6 +23,8 @@ import aws.sdk.kotlin.services.s3.S3Client import aws.sdk.kotlin.services.s3.model.CreateMultipartUploadResponse import aws.sdk.kotlin.services.s3.withConfig import com.amplifyframework.storage.TransferState +import com.amplifyframework.storage.s3.transfer.S3StorageTransferClientProvider +import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider import com.amplifyframework.storage.s3.transfer.TransferDB import com.amplifyframework.storage.s3.transfer.TransferRecord import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater @@ -51,12 +53,14 @@ internal class InitiateMultiPartUploadTransferWorkerTest { private lateinit var transferDB: TransferDB private lateinit var transferStatusUpdater: TransferStatusUpdater private lateinit var workerParameters: WorkerParameters + private lateinit var clientProvider: StorageTransferClientProvider @Before fun setup() { context = ApplicationProvider.getApplicationContext() workerParameters = mockk(WorkerParameters::class.java.name) s3Client = mockk(relaxed = true) + clientProvider = mockk(S3StorageTransferClientProvider::class.java.name) mockkStatic(S3Client::withConfig) transferDB = mockk(TransferDB::class.java.name) transferStatusUpdater = mockk(TransferStatusUpdater::class.java.name) @@ -64,6 +68,7 @@ internal class InitiateMultiPartUploadTransferWorkerTest { every { workerParameters.runAttemptCount }.answers { 1 } every { workerParameters.taskExecutor }.answers { ImmediateTaskExecutor() } every { s3Client.withConfig(any()) } returns s3Client + every { clientProvider.getStorageTransferClient(any(), any())}.answers { s3Client } } @After @@ -89,7 +94,7 @@ internal class InitiateMultiPartUploadTransferWorkerTest { every { transferStatusUpdater.updateMultipartId(1, "upload_id") }.answers { } every { transferStatusUpdater.updateTransferState(any(), TransferState.IN_PROGRESS) }.answers { } val worker = InitiateMultiPartUploadTransferWorker( - s3Client, + clientProvider, transferDB, transferStatusUpdater, context, @@ -97,6 +102,7 @@ internal class InitiateMultiPartUploadTransferWorkerTest { ) val result = worker.doWork() verify(exactly = 1) { transferStatusUpdater.updateMultipartId(1, "upload_id") } + verify(exactly = 1) { clientProvider.getStorageTransferClient(any(), any()) } val output = workDataOf( BaseTransferWorker.MULTI_PART_UPLOAD_ID to "upload_id", BaseTransferWorker.TRANSFER_RECORD_ID to 1 @@ -124,7 +130,7 @@ internal class InitiateMultiPartUploadTransferWorkerTest { every { transferStatusUpdater.updateTransferState(any(), any()) }.answers { } val worker = InitiateMultiPartUploadTransferWorker( - s3Client, + clientProvider, transferDB, transferStatusUpdater, context,