From 7c2b7abe8ea1d10b3537cb174f7dd1cf220da37d Mon Sep 17 00:00:00 2001 From: Di Wu Date: Thu, 31 Aug 2023 10:26:00 -0700 Subject: [PATCH] change interceptors applying logic --- .../AWSAPICategoryPluginConfiguration.swift | 27 +++-- .../AWSAPIEndpointInterceptors.swift | 14 ++- .../AuthTokenURLRequestInterceptor.swift | 4 +- .../IAMURLRequestInterceptor.swift | 1 - .../Operation/AWSGraphQLOperation.swift | 81 ++++++++++----- .../Operation/AWSRESTOperation.swift | 98 +++++++++++++------ .../Utils/GraphQLOperationRequestUtils.swift | 4 +- .../Utils/RESTOperationRequestUtils.swift | 30 ++---- .../Support/Utils/Result+Async.swift | 20 ++++ .../RESTWithIAMIntegrationTests.swift | 1 - ...egoryPlugin+InterceptorBehaviorTests.swift | 8 +- ...SAPICategoryPluginConfigurationTests.swift | 22 +++-- .../AWSAPIEndpointInterceptorsTests.swift | 19 ++-- .../AuthTokenURLRequestInterceptorTests.swift | 6 +- .../Support/Utils/RESTRequestUtilsTests.swift | 13 --- .../Support/Utils/Result+AsyncTests.swift | 53 ++++++++++ 16 files changed, 265 insertions(+), 136 deletions(-) create mode 100644 AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/Result+Async.swift create mode 100644 AmplifyPlugins/API/Tests/AWSAPIPluginTests/Support/Utils/Result+AsyncTests.swift diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Configuration/AWSAPICategoryPluginConfiguration.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Configuration/AWSAPICategoryPluginConfiguration.swift index 9a4d894c41..b4aa82783b 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Configuration/AWSAPICategoryPluginConfiguration.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Configuration/AWSAPICategoryPluginConfiguration.swift @@ -84,22 +84,19 @@ public struct AWSAPICategoryPluginConfiguration { interceptors[apiName]?.addInterceptor(interceptor) } - /// Returns all the interceptors registered for `apiName` API endpoint + /// Returns all the customer defined interceptors registered for `apiName` API endpoint /// - Parameter apiName: API endpoint name /// - Returns: request interceptors - internal func interceptorsForEndpoint(named apiName: APIEndpointName) -> [URLRequestInterceptor] { - guard let interceptorsConfig = interceptors[apiName] else { - return [] - } - return interceptorsConfig.interceptors + internal func interceptorsForEndpoint(named apiName: APIEndpointName) -> AWSAPIEndpointInterceptors? { + return interceptors[apiName] } - /// Returns interceptors for the provided endpointConfig + /// Returns customer defined interceptors for the provided endpointConfig /// - Parameters: /// - endpointConfig: endpoint configuration /// - Throws: PluginConfigurationError in case of failure building an instance of AWSAuthorizationConfiguration /// - Returns: An array of URLRequestInterceptor - internal func interceptorsForEndpoint(withConfig endpointConfig: EndpointConfig) throws -> [URLRequestInterceptor] { + internal func interceptorsForEndpoint(withConfig endpointConfig: EndpointConfig) -> AWSAPIEndpointInterceptors? { return interceptorsForEndpoint(named: endpointConfig.name) } @@ -109,8 +106,10 @@ public struct AWSAPICategoryPluginConfiguration { /// - authType: overrides the registered auth interceptor /// - Throws: PluginConfigurationError in case of failure building an instance of AWSAuthorizationConfiguration /// - Returns: An array of URLRequestInterceptor - internal func interceptorsForEndpoint(withConfig endpointConfig: EndpointConfig, - authType: AWSAuthorizationType) throws -> [URLRequestInterceptor] { + internal func interceptorsForEndpoint( + withConfig endpointConfig: EndpointConfig, + authType: AWSAuthorizationType + ) throws -> AWSAPIEndpointInterceptors? { guard let apiAuthProviderFactory = self.apiAuthProviderFactory else { return interceptorsForEndpoint(named: endpointConfig.name) @@ -126,12 +125,10 @@ public struct AWSAPICategoryPluginConfiguration { authConfiguration: authConfiguration) // retrieve current interceptors and replace auth interceptor - let currentInterceptors = interceptorsForEndpoint(named: endpointConfig.name).filter { - !isAuthInterceptor($0) - } - config.interceptors.append(contentsOf: currentInterceptors) + let currentInterceptors = interceptorsForEndpoint(named: endpointConfig.name) + config.interceptors.append(contentsOf: currentInterceptors?.interceptors ?? []) - return config.interceptors + return config } // MARK: Private diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Configuration/AWSAPIEndpointInterceptors.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Configuration/AWSAPIEndpointInterceptors.swift index 13f6ba16e8..1e41e22894 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Configuration/AWSAPIEndpointInterceptors.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Configuration/AWSAPIEndpointInterceptors.swift @@ -16,6 +16,10 @@ struct AWSAPIEndpointInterceptors { let apiAuthProviderFactory: APIAuthProviderFactory let authService: AWSAuthServiceBehavior? + var amplifyInterceptors: [URLRequestInterceptor] = [] + + var checksumInterceptors: [URLRequestInterceptor] = [] + var interceptors: [URLRequestInterceptor] = [] init(endpointName: APIEndpointName, @@ -42,7 +46,7 @@ struct AWSAPIEndpointInterceptors { case .apiKey(let apiKeyConfig): let provider = BasicAPIKeyProvider(apiKey: apiKeyConfig.apiKey) let interceptor = APIKeyURLRequestInterceptor(apiKeyProvider: provider) - addInterceptor(interceptor) + amplifyInterceptors.append(interceptor) case .awsIAM(let iamConfig): guard let authService = authService else { throw PluginError.pluginConfigurationError("AuthService is not set for IAM", @@ -52,7 +56,7 @@ struct AWSAPIEndpointInterceptors { let interceptor = IAMURLRequestInterceptor(iamCredentialsProvider: provider, region: iamConfig.region, endpointType: endpointType) - addInterceptor(interceptor) + checksumInterceptors.append(interceptor) case .amazonCognitoUserPools: guard let authService = authService else { throw PluginError.pluginConfigurationError("AuthService not set for cognito user pools", @@ -60,7 +64,7 @@ struct AWSAPIEndpointInterceptors { } let provider = BasicUserPoolTokenProvider(authService: authService) let interceptor = AuthTokenURLRequestInterceptor(authTokenProvider: provider) - addInterceptor(interceptor) + amplifyInterceptors.append(interceptor) case .openIDConnect: guard let oidcAuthProvider = apiAuthProviderFactory.oidcAuthProvider() else { throw PluginError.pluginConfigurationError("AuthService not set for OIDC", @@ -68,7 +72,7 @@ struct AWSAPIEndpointInterceptors { } let wrappedAuthProvider = AuthTokenProviderWrapper(tokenAuthProvider: oidcAuthProvider) let interceptor = AuthTokenURLRequestInterceptor(authTokenProvider: wrappedAuthProvider) - addInterceptor(interceptor) + amplifyInterceptors.append(interceptor) case .function: guard let functionAuthProvider = apiAuthProviderFactory.functionAuthProvider() else { throw PluginError.pluginConfigurationError("AuthService not set for function auth", @@ -76,7 +80,7 @@ struct AWSAPIEndpointInterceptors { } let wrappedAuthProvider = AuthTokenProviderWrapper(tokenAuthProvider: functionAuthProvider) let interceptor = AuthTokenURLRequestInterceptor(authTokenProvider: wrappedAuthProvider) - addInterceptor(interceptor) + amplifyInterceptors.append(interceptor) } } } diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/RequestInterceptor/AuthTokenURLRequestInterceptor.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/RequestInterceptor/AuthTokenURLRequestInterceptor.swift index 0175364b7a..d7392748e8 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/RequestInterceptor/AuthTokenURLRequestInterceptor.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/RequestInterceptor/AuthTokenURLRequestInterceptor.swift @@ -32,9 +32,7 @@ struct AuthTokenURLRequestInterceptor: URLRequestInterceptor { mutableRequest.setValue(amzDate, forHTTPHeaderField: URLRequestConstants.Header.xAmzDate) - mutableRequest.setValue(URLRequestConstants.ContentType.applicationJson, - forHTTPHeaderField: URLRequestConstants.Header.contentType) - mutableRequest.setValue(userAgent, + mutableRequest.addValue(userAgent, forHTTPHeaderField: URLRequestConstants.Header.userAgent) let token: String diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/RequestInterceptor/IAMURLRequestInterceptor.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/RequestInterceptor/IAMURLRequestInterceptor.swift index 013cac0459..0a9b53fe93 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/RequestInterceptor/IAMURLRequestInterceptor.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/RequestInterceptor/IAMURLRequestInterceptor.swift @@ -36,7 +36,6 @@ struct IAMURLRequestInterceptor: URLRequestInterceptor { throw APIError.unknown("Could not get host from mutable request", "") } - request.setValue(URLRequestConstants.ContentType.applicationJson, forHTTPHeaderField: URLRequestConstants.Header.contentType) request.setValue(host, forHTTPHeaderField: "host") request.setValue(userAgent, forHTTPHeaderField: URLRequestConstants.Header.userAgent) diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLOperation.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLOperation.swift index 2f6ed4ba27..94b24a765a 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLOperation.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLOperation.swift @@ -56,7 +56,7 @@ final public class AWSGraphQLOperation: GraphQLOperation { // Retrieve endpoint configuration let endpointConfig: AWSAPICategoryPluginConfiguration.EndpointConfig - let requestInterceptors: [URLRequestInterceptor] + let requestInterceptors: AWSAPIEndpointInterceptors? do { endpointConfig = try pluginConfig.endpoints.getConfig(for: request.apiName, endpointType: .graphQL) @@ -66,7 +66,7 @@ final public class AWSGraphQLOperation: GraphQLOperation { requestInterceptors = try pluginConfig.interceptorsForEndpoint(withConfig: endpointConfig, authType: authType) } else { - requestInterceptors = try pluginConfig.interceptorsForEndpoint(withConfig: endpointConfig) + requestInterceptors = pluginConfig.interceptorsForEndpoint(withConfig: endpointConfig) } } catch let error as APIError { dispatch(result: .failure(error)) @@ -101,34 +101,67 @@ final public class AWSGraphQLOperation: GraphQLOperation { // Create request let urlRequest = GraphQLOperationRequestUtils.constructRequest(with: endpointConfig.baseURL, requestPayload: requestPayload) - + let amplifyInterceptors = requestInterceptors?.amplifyInterceptors ?? [] + let customerInterceptors = requestInterceptors?.interceptors ?? [] + let checksumInterceptors = requestInterceptors?.checksumInterceptors ?? [] Task { - // Intercept request - var finalRequest = urlRequest - for interceptor in requestInterceptors { - do { - finalRequest = try await interceptor.intercept(finalRequest) - } catch let error as APIError { - dispatch(result: .failure(error)) - cancel() - } catch { - dispatch(result: .failure(APIError.operationError("Failed to intercept request fully.", - "Something wrong with the interceptor", - error))) - cancel() + var finalResult: Result = .success(urlRequest) + // apply amplify interceptors + for interceptor in amplifyInterceptors { + finalResult = await finalResult.flatMapAsync { request in + await applyInterceptor(interceptor, request: request) + } + } + + // there is no customer headers for GraphQLOperationRequest + + // apply customer interceptors + for interceptor in customerInterceptors { + finalResult = await finalResult.flatMapAsync { request in + await applyInterceptor(interceptor, request: request) + } + } + + // apply checksum interceptor + for interceptor in checksumInterceptors { + finalResult = await finalResult.flatMapAsync { request in + await applyInterceptor(interceptor, request: request) } } - if isCancelled { - finish() - return + switch finalResult { + case .success(let finalRequest): + if isCancelled { + finish() + return + } + + // Begin network task + Amplify.API.log.debug("Starting network task for \(request.operationType) \(id)") + let task = session.dataTaskBehavior(with: finalRequest) + mapper.addPair(operation: self, task: task) + task.resume() + case .failure(let error): + dispatch(result: .failure(error)) + cancel() } + } + } - // Begin network task - Amplify.API.log.debug("Starting network task for \(request.operationType) \(id)") - let task = session.dataTaskBehavior(with: finalRequest) - mapper.addPair(operation: self, task: task) - task.resume() + private func applyInterceptor(_ interceptor: URLRequestInterceptor, request: URLRequest) async -> Result { + do { + return .success(try await interceptor.intercept(request)) + } catch let error as APIError { + return .failure(error) + } catch { + return .failure( + APIError.operationError( + "Failed to intercept request fully.", + "Something wrong with the interceptor", + error + ) + ) } } + } diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSRESTOperation.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSRESTOperation.swift index ca2324f4a5..6beb8b8b79 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSRESTOperation.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSRESTOperation.swift @@ -60,10 +60,15 @@ final public class AWSRESTOperation: AmplifyOperation< // Retrieve endpoint configuration let endpointConfig: AWSAPICategoryPluginConfiguration.EndpointConfig - let requestInterceptors: [URLRequestInterceptor] + let amplifyInterceptors: [URLRequestInterceptor] + let customerInterceptors: [URLRequestInterceptor] + let checksumInterceptors: [URLRequestInterceptor] do { endpointConfig = try pluginConfig.endpoints.getConfig(for: request.apiName, endpointType: .rest) - requestInterceptors = try pluginConfig.interceptorsForEndpoint(withConfig: endpointConfig) + let interceptorConfig = pluginConfig.interceptorsForEndpoint(withConfig: endpointConfig) + amplifyInterceptors = interceptorConfig?.amplifyInterceptors ?? [] + customerInterceptors = interceptorConfig?.interceptors ?? [] + checksumInterceptors = interceptorConfig?.checksumInterceptors ?? [] } catch let error as APIError { dispatch(result: .failure(error)) finish() @@ -94,41 +99,76 @@ final public class AWSRESTOperation: AmplifyOperation< } // Construct URL Request with url and request body - let urlRequest = RESTOperationRequestUtils.constructURLRequest(with: url, - operationType: request.operationType, - headers: request.headers, - requestPayload: request.body) + let urlRequest = RESTOperationRequestUtils.constructURLRequest( + with: url, + operationType: request.operationType, + requestPayload: request.body + ) Task { - // Intercept request - var finalRequest = urlRequest - for interceptor in requestInterceptors { - do { - finalRequest = try await interceptor.intercept(finalRequest) - } catch let error as APIError { - dispatch(result: .failure(error)) - cancel() - } catch { - dispatch(result: .failure(APIError.operationError("Failed to intercept request fully.", - "Something wrong with the interceptor", - error))) - cancel() + var finalResult: Result = .success(urlRequest) + // apply amplify interceptors + for interceptor in amplifyInterceptors { + finalResult = await finalResult.flatMapAsync { request in + await applyInterceptor(interceptor, request: request) } } - // The headers from the request object should override any of the same header modifications done by the intercepters above - finalRequest = RESTOperationRequestUtils.applyCustomizeRequestHeaders(request.headers, on: finalRequest) + // apply customer headers + finalResult = finalResult.map { urlRequest in + var mutableRequest = urlRequest + for (key, value) in request.headers ?? [:] { + mutableRequest.setValue(value, forHTTPHeaderField: key) + } + return mutableRequest + } + + // apply customer interceptors + for interceptor in customerInterceptors { + finalResult = await finalResult.flatMapAsync { request in + await applyInterceptor(interceptor, request: request) + } + } + + // apply checksum interceptor + for interceptor in checksumInterceptors { + finalResult = await finalResult.flatMapAsync { request in + await applyInterceptor(interceptor, request: request) + } + } - if isCancelled { - finish() - return + switch finalResult { + case .success(let finalRequest): + if isCancelled { + finish() + return + } + + // Begin network task + Amplify.API.log.debug("Starting network task for \(request.operationType) \(id)") + let task = session.dataTaskBehavior(with: finalRequest) + mapper.addPair(operation: self, task: task) + task.resume() + case .failure(let error): + dispatch(result: .failure(error)) + cancel() } + } + } - // Begin network task - Amplify.API.log.debug("Starting network task for \(request.operationType) \(id)") - let task = session.dataTaskBehavior(with: finalRequest) - mapper.addPair(operation: self, task: task) - task.resume() + private func applyInterceptor(_ interceptor: URLRequestInterceptor, request: URLRequest) async -> Result { + do { + return .success(try await interceptor.intercept(request)) + } catch let error as APIError { + return .failure(error) + } catch { + return .failure( + APIError.operationError( + "Failed to intercept request fully.", + "Something wrong with the interceptor", + error + ) + ) } } } diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/GraphQLOperationRequestUtils.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/GraphQLOperationRequestUtils.swift index 18f695e68b..80e7bc74df 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/GraphQLOperationRequestUtils.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/GraphQLOperationRequestUtils.swift @@ -22,8 +22,8 @@ class GraphQLOperationRequestUtils { // Construct a graphQL specific HTTP POST request with the request payload static func constructRequest(with baseUrl: URL, requestPayload: Data) -> URLRequest { var baseRequest = URLRequest(url: baseUrl) - let headers = ["content-type": "application/json", "Cache-Control": "no-store"] - baseRequest.allHTTPHeaderFields = headers + baseRequest.setValue("application/json", forHTTPHeaderField: "content-type") + baseRequest.setValue("no-store", forHTTPHeaderField: "cache-control") baseRequest.httpMethod = "POST" baseRequest.httpBody = requestPayload diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/RESTOperationRequestUtils.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/RESTOperationRequestUtils.swift index a553dcef71..a73f46eaae 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/RESTOperationRequestUtils.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/RESTOperationRequestUtils.swift @@ -52,36 +52,18 @@ final class RESTOperationRequestUtils { } // Construct a request specific to the `RESTOperationType` - static func constructURLRequest(with url: URL, - operationType: RESTOperationType, - headers: [String: String]?, - requestPayload: Data?) -> URLRequest { - - let baseHeaders = ["content-type": "application/json"] + static func constructURLRequest( + with url: URL, + operationType: RESTOperationType, + requestPayload: Data? + ) -> URLRequest { var baseRequest = URLRequest(url: url) - baseRequest = applyCustomizeRequestHeaders( - baseHeaders.merging(headers ?? [:], uniquingKeysWith: { _, new in new }), - on: baseRequest - ) + baseRequest.setValue("application/json", forHTTPHeaderField: "content-type") baseRequest.httpMethod = operationType.rawValue baseRequest.httpBody = requestPayload return baseRequest } - static func applyCustomizeRequestHeaders(_ headers: [String: String]?, on request: URLRequest) -> URLRequest { - guard let headers = headers, - let mutableRequest = (request as NSURLRequest).mutableCopy() as? NSMutableURLRequest - else { - return request - } - - for (key, value) in headers { - mutableRequest.setValue(value, forHTTPHeaderField: key) - } - - return mutableRequest as URLRequest - } - private static let permittedQueryParamCharacters = CharacterSet.alphanumerics .union(.init(charactersIn: "/_-.~")) diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/Result+Async.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/Result+Async.swift new file mode 100644 index 0000000000..fbf204c971 --- /dev/null +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/Result+Async.swift @@ -0,0 +1,20 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import Foundation + +extension Result { + func flatMapAsync(_ f: (Success) async -> Result) async -> Result { + switch self { + case .success(let value): + return await f(value) + case .failure(let error): + return .failure(error) + } + } +} diff --git a/AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginRESTIAMTests/RESTWithIAMIntegrationTests.swift b/AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginRESTIAMTests/RESTWithIAMIntegrationTests.swift index c95639c7c7..b63c8d2fda 100644 --- a/AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginRESTIAMTests/RESTWithIAMIntegrationTests.swift +++ b/AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginRESTIAMTests/RESTWithIAMIntegrationTests.swift @@ -164,7 +164,6 @@ class RESTWithIAMIntegrationTests: XCTestCase { let request = RESTRequest(path: "/items", headers: ["Content-Type": "text/plain"]) do { _ = try await Amplify.API.get(request: request) - XCTFail("Should catch error") } catch { guard let apiError = error as? APIError else { XCTFail("Error should be APIError") diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AWSAPICategoryPlugin+InterceptorBehaviorTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AWSAPICategoryPlugin+InterceptorBehaviorTests.swift index adbca6ccfb..c746d36606 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AWSAPICategoryPlugin+InterceptorBehaviorTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AWSAPICategoryPlugin+InterceptorBehaviorTests.swift @@ -14,12 +14,16 @@ class AWSAPICategoryPluginInterceptorBehaviorTests: AWSAPICategoryPluginTestBase func testAddInterceptor() throws { XCTAssertNotNil(apiPlugin.pluginConfig.endpoints[apiName]) - XCTAssertEqual(apiPlugin.pluginConfig.interceptorsForEndpoint(named: apiName).count, 0) + XCTAssertEqual(apiPlugin.pluginConfig.interceptorsForEndpoint(named: apiName)?.amplifyInterceptors.count, 0) + XCTAssertEqual(apiPlugin.pluginConfig.interceptorsForEndpoint(named: apiName)?.interceptors.count, 0) + XCTAssertEqual(apiPlugin.pluginConfig.interceptorsForEndpoint(named: apiName)?.checksumInterceptors.count, 0) let provider = BasicUserPoolTokenProvider(authService: authService) let requestInterceptor = AuthTokenURLRequestInterceptor(authTokenProvider: provider) try apiPlugin.add(interceptor: requestInterceptor, for: apiName) - XCTAssertEqual(apiPlugin.pluginConfig.interceptorsForEndpoint(named: apiName).count, 1) + XCTAssertEqual(apiPlugin.pluginConfig.interceptorsForEndpoint(named: apiName)?.amplifyInterceptors.count, 0) + XCTAssertEqual(apiPlugin.pluginConfig.interceptorsForEndpoint(named: apiName)?.interceptors.count, 1) + XCTAssertEqual(apiPlugin.pluginConfig.interceptorsForEndpoint(named: apiName)?.checksumInterceptors.count, 0) } } diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Configuration/AWSAPICategoryPluginConfigurationTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Configuration/AWSAPICategoryPluginConfigurationTests.swift index 733065779e..b50c929011 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Configuration/AWSAPICategoryPluginConfigurationTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Configuration/AWSAPICategoryPluginConfigurationTests.swift @@ -58,18 +58,22 @@ class AWSAPICategoryPluginConfigurationTests: XCTestCase { func testAddInterceptors() { let apiKeyInterceptor = APIKeyURLRequestInterceptor(apiKeyProvider: BasicAPIKeyProvider(apiKey: apiKey)) config?.addInterceptor(apiKeyInterceptor, toEndpoint: graphQLAPI) - XCTAssertEqual(config?.interceptorsForEndpoint(named: graphQLAPI).count, 1) + XCTAssertEqual(config?.interceptorsForEndpoint(named: graphQLAPI)?.amplifyInterceptors.count, 0) + XCTAssertEqual(config?.interceptorsForEndpoint(named: graphQLAPI)?.interceptors.count, 1) + XCTAssertEqual(config?.interceptorsForEndpoint(named: graphQLAPI)?.checksumInterceptors.count, 0) } /// Given: multiple interceptors conforming to URLRequestInterceptor and an EndpointConfig /// When: interceptorsForEndpoint is called with the given EndpointConfig /// Then: the registered interceptors are returned - func testInterceptorsForEndpointWithConfig() throws { + func testInterceptorsForEndpointWithConfig() { let apiKeyInterceptor = APIKeyURLRequestInterceptor(apiKeyProvider: BasicAPIKeyProvider(apiKey: apiKey)) config?.addInterceptor(apiKeyInterceptor, toEndpoint: graphQLAPI) config?.addInterceptor(CustomURLInterceptor(), toEndpoint: graphQLAPI) - let interceptors = try config?.interceptorsForEndpoint(withConfig: endpointConfig!) - XCTAssertEqual(interceptors!.count, 2) + let interceptors = config?.interceptorsForEndpoint(withConfig: endpointConfig!) + XCTAssertEqual(interceptors!.amplifyInterceptors.count, 0) + XCTAssertEqual(interceptors!.interceptors.count, 2) + XCTAssertEqual(interceptors!.checksumInterceptors.count, 0) } /// Given: multiple interceptors conforming to URLRequestInterceptor @@ -84,9 +88,9 @@ class AWSAPICategoryPluginConfigurationTests: XCTestCase { let interceptors = try config?.interceptorsForEndpoint(withConfig: endpointConfig!, authType: .amazonCognitoUserPools) - XCTAssertEqual(interceptors!.count, 2) - XCTAssertNotNil(interceptors![0] as? AuthTokenURLRequestInterceptor) - XCTAssertNotNil(interceptors![1] as? CustomURLInterceptor) + XCTAssertEqual(interceptors!.amplifyInterceptors.count, 1) + XCTAssertNotNil(interceptors!.amplifyInterceptors[0] as? AuthTokenURLRequestInterceptor) + XCTAssertNotNil(interceptors!.interceptors[1] as? CustomURLInterceptor) } /// Given: an auth interceptor conforming to URLRequestInterceptor @@ -99,8 +103,8 @@ class AWSAPICategoryPluginConfigurationTests: XCTestCase { let interceptors = try config?.interceptorsForEndpoint(withConfig: endpointConfig!, authType: .apiKey) - XCTAssertEqual(interceptors!.count, 1) - XCTAssertNotNil(interceptors![0] as? APIKeyURLRequestInterceptor) + XCTAssertEqual(interceptors!.amplifyInterceptors.count, 1) + XCTAssertNotNil(interceptors!.amplifyInterceptors[0] as? APIKeyURLRequestInterceptor) } // MARK: - Helpers diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Configuration/AWSAPIEndpointInterceptorsTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Configuration/AWSAPIEndpointInterceptorsTests.swift index 5cab1d4a91..d9a435b236 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Configuration/AWSAPIEndpointInterceptorsTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Configuration/AWSAPIEndpointInterceptorsTests.swift @@ -31,8 +31,9 @@ class AWSAPIEndpointInterceptorsTests: XCTestCase { try interceptorConfig.addAuthInterceptorsToEndpoint(endpointType: .graphQL, authConfiguration: config!) - XCTAssertEqual(interceptorConfig.interceptors.count, 1) - + XCTAssertEqual(interceptorConfig.amplifyInterceptors.count, 1) + XCTAssertEqual(interceptorConfig.interceptors.count, 0) + XCTAssertEqual(interceptorConfig.checksumInterceptors.count, 0) } /// Given: an AWSAPIEndpointInterceptors @@ -44,7 +45,9 @@ class AWSAPIEndpointInterceptorsTests: XCTestCase { try interceptorConfig.addAuthInterceptorsToEndpoint(endpointType: .graphQL, authConfiguration: config!) interceptorConfig.addInterceptor(CustomInterceptor()) - XCTAssertEqual(interceptorConfig.interceptors.count, 2) + XCTAssertEqual(interceptorConfig.amplifyInterceptors.count, 1) + XCTAssertEqual(interceptorConfig.interceptors.count, 1) + XCTAssertEqual(interceptorConfig.checksumInterceptors.count, 0) } func testaddMultipleAuthInterceptors() throws { @@ -69,10 +72,12 @@ class AWSAPIEndpointInterceptorsTests: XCTestCase { try interceptorConfig.addAuthInterceptorsToEndpoint(endpointType: .graphQL, authConfiguration: userPoolConfig) - XCTAssertEqual(interceptorConfig.interceptors.count, 3) - XCTAssertNotNil(interceptorConfig.interceptors[0] as? APIKeyURLRequestInterceptor) - XCTAssertNotNil(interceptorConfig.interceptors[1] as? IAMURLRequestInterceptor) - XCTAssertNotNil(interceptorConfig.interceptors[2] as? AuthTokenURLRequestInterceptor) + XCTAssertEqual(interceptorConfig.amplifyInterceptors.count, 2) + XCTAssertEqual(interceptorConfig.interceptors.count, 0) + XCTAssertEqual(interceptorConfig.checksumInterceptors.count, 1) + XCTAssertNotNil(interceptorConfig.amplifyInterceptors[0] as? APIKeyURLRequestInterceptor) + XCTAssertNotNil(interceptorConfig.amplifyInterceptors[1] as? AuthTokenURLRequestInterceptor) + XCTAssertNotNil(interceptorConfig.checksumInterceptors[0] as? IAMURLRequestInterceptor) } // MARK: - Test Helpers diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/AuthTokenURLRequestInterceptorTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/AuthTokenURLRequestInterceptorTests.swift index ed1df2905b..25370ac15b 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/AuthTokenURLRequestInterceptorTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/AuthTokenURLRequestInterceptorTests.swift @@ -15,7 +15,11 @@ class AuthTokenURLRequestInterceptorTests: XCTestCase { func testAuthTokenInterceptor() async throws { let mockTokenProvider = MockTokenProvider() let interceptor = AuthTokenURLRequestInterceptor(authTokenProvider: mockTokenProvider) - let request = URLRequest(url: URL(string: "http://anapiendpoint.ca")!) + let request = RESTOperationRequestUtils.constructURLRequest( + with: URL(string: "http://anapiendpoint.ca")!, + operationType: .get, + requestPayload: nil + ) guard let headers = try await interceptor.intercept(request).allHTTPHeaderFields else { XCTFail("Failed retrieving headers") diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Support/Utils/RESTRequestUtilsTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Support/Utils/RESTRequestUtilsTests.swift index 4de8148fcc..7d3af18917 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Support/Utils/RESTRequestUtilsTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Support/Utils/RESTRequestUtilsTests.swift @@ -70,7 +70,6 @@ class RESTRequestUtilsTests: XCTestCase { let urlRequest = RESTOperationRequestUtils.constructURLRequest( with: url, operationType: .get, - headers: nil, requestPayload: nil ) @@ -95,18 +94,6 @@ class RESTRequestUtilsTests: XCTestCase { ) ) } - - func testApplyCustomizeRequestHeaders_withCutomeHeaders_successfullyOverride() { - var request = URLRequest(url: URL(string: "https://aws.amazon.com")!) - request.allHTTPHeaderFields = ["Content-Type": "application/json"] - let headers = ["Content-Type": "text/plain"] - let requestWithHeaders = RESTOperationRequestUtils.applyCustomizeRequestHeaders(headers, on: request) - XCTAssertNotNil(requestWithHeaders.allHTTPHeaderFields) - for (key, value) in headers { - XCTAssertEqual(requestWithHeaders.allHTTPHeaderFields![key], value) - } - } - } extension RESTRequestUtilsTests { diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Support/Utils/Result+AsyncTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Support/Utils/Result+AsyncTests.swift new file mode 100644 index 0000000000..148d630fff --- /dev/null +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Support/Utils/Result+AsyncTests.swift @@ -0,0 +1,53 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import XCTest +@testable import AWSAPIPlugin + +class ResultAsyncTests: XCTestCase { + + func testFlatMapAsync_withSuccess_applyFunction() async { + func plus1(_ number: Int) async -> Int { + return number + 1 + } + + let result = Result.success(0) + let plus1Result = await result.flatMapAsync { + .success(await plus1($0)) + } + + switch plus1Result { + case .success(let plus1Result): + XCTAssertEqual(plus1Result, 1) + case .failure(let error): + XCTFail("Failed with error \(error)") + } + } + + func testFlatMapAsync_withFailure_notApplyFunction() async { + func arrayCount(_ array: [Int]) async -> Int { + return array.count + } + + + let expectedError = TestError() + let result = Result<[Int], Error>.failure(expectedError) + let count = await result.flatMapAsync { + .success(await arrayCount($0)) + } + + switch count { + case .success: + XCTFail("Should fail") + case .failure(let error): + XCTAssertTrue(error is TestError) + } + } +} + +fileprivate class TestError: Error { }