From 6f95db94d766921c92fef5631cb9b867899bdeca Mon Sep 17 00:00:00 2001 From: Di Wu Date: Tue, 12 Sep 2023 12:56:51 -0700 Subject: [PATCH] fix(api): change request interceptors applying logic (#3190) * fix(api): apply cutomize request headers after interceptors * add integration test case * change interceptors applying logic * refactor code style * refactor code style for rest operations * add unit test case for customer header override * update interceptor names * update interceptor doc comment --- .../AWSAPICategoryPluginConfiguration.swift | 34 ++-- .../AWSAPIEndpointInterceptors.swift | 22 ++- .../AuthTokenURLRequestInterceptor.swift | 4 +- .../IAMURLRequestInterceptor.swift | 1 - .../Operation/AWSGraphQLOperation.swift | 182 ++++++++++++------ .../Operation/AWSRESTOperation.swift | 169 +++++++++------- .../Utils/GraphQLOperationRequestUtils.swift | 4 +- .../Utils/RESTOperationRequestUtils.swift | 18 +- .../Support/Utils/Result+Async.swift | 20 ++ .../GraphQLConnectionScenario4Tests.swift | 2 +- .../GraphQLModelBasedTests+List.swift | 2 +- .../RESTWithIAMIntegrationTests.swift | 18 ++ ...egoryPlugin+InterceptorBehaviorTests.swift | 8 +- ...SAPICategoryPluginConfigurationTests.swift | 22 ++- .../AWSAPIEndpointInterceptorsTests.swift | 19 +- .../AuthTokenURLRequestInterceptorTests.swift | 6 +- .../Operation/AWSRESTOperationTests.swift | 36 ++++ .../Support/Utils/RESTRequestUtilsTests.swift | 1 - .../Support/Utils/Result+AsyncTests.swift | 54 ++++++ 19 files changed, 430 insertions(+), 192 deletions(-) create mode 100644 AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/Result+Async.swift create mode 100644 AmplifyPlugins/API/Tests/AWSAPIPluginTests/Support/Utils/Result+AsyncTests.swift diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Configuration/AWSAPICategoryPluginConfiguration.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Configuration/AWSAPICategoryPluginConfiguration.swift index 9a4d894c41..200a813e96 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Configuration/AWSAPICategoryPluginConfiguration.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Configuration/AWSAPICategoryPluginConfiguration.swift @@ -72,7 +72,7 @@ public struct AWSAPICategoryPluginConfiguration { self.authService = authService } - /// Registers an interceptor for the provided API endpoint + /// Registers an customer interceptor for the provided API endpoint /// - Parameter interceptor: operation interceptor used to decorate API requests /// - Parameter toEndpoint: API endpoint name mutating func addInterceptor(_ interceptor: URLRequestInterceptor, @@ -86,20 +86,16 @@ public struct AWSAPICategoryPluginConfiguration { /// Returns all the interceptors registered for `apiName` API endpoint /// - Parameter apiName: API endpoint name - /// - Returns: request interceptors - internal func interceptorsForEndpoint(named apiName: APIEndpointName) -> [URLRequestInterceptor] { - guard let interceptorsConfig = interceptors[apiName] else { - return [] - } - return interceptorsConfig.interceptors + /// - Returns: Optional AWSAPIEndpointInterceptors for the apiName + internal func interceptorsForEndpoint(named apiName: APIEndpointName) -> AWSAPIEndpointInterceptors? { + return interceptors[apiName] } - /// Returns interceptors for the provided endpointConfig + /// Returns the interceptors for the provided endpointConfig /// - Parameters: /// - endpointConfig: endpoint configuration - /// - Throws: PluginConfigurationError in case of failure building an instance of AWSAuthorizationConfiguration - /// - Returns: An array of URLRequestInterceptor - internal func interceptorsForEndpoint(withConfig endpointConfig: EndpointConfig) throws -> [URLRequestInterceptor] { + /// - Returns: Optional AWSAPIEndpointInterceptors for the endpointConfig + internal func interceptorsForEndpoint(withConfig endpointConfig: EndpointConfig) -> AWSAPIEndpointInterceptors? { return interceptorsForEndpoint(named: endpointConfig.name) } @@ -108,9 +104,11 @@ public struct AWSAPICategoryPluginConfiguration { /// - endpointConfig: endpoint configuration /// - authType: overrides the registered auth interceptor /// - Throws: PluginConfigurationError in case of failure building an instance of AWSAuthorizationConfiguration - /// - Returns: An array of URLRequestInterceptor - internal func interceptorsForEndpoint(withConfig endpointConfig: EndpointConfig, - authType: AWSAuthorizationType) throws -> [URLRequestInterceptor] { + /// - Returns: Optional AWSAPIEndpointInterceptors for the endpointConfig and authType + internal func interceptorsForEndpoint( + withConfig endpointConfig: EndpointConfig, + authType: AWSAuthorizationType + ) throws -> AWSAPIEndpointInterceptors? { guard let apiAuthProviderFactory = self.apiAuthProviderFactory else { return interceptorsForEndpoint(named: endpointConfig.name) @@ -126,12 +124,10 @@ public struct AWSAPICategoryPluginConfiguration { authConfiguration: authConfiguration) // retrieve current interceptors and replace auth interceptor - let currentInterceptors = interceptorsForEndpoint(named: endpointConfig.name).filter { - !isAuthInterceptor($0) - } - config.interceptors.append(contentsOf: currentInterceptors) + let currentInterceptors = interceptorsForEndpoint(named: endpointConfig.name) + config.interceptors.append(contentsOf: currentInterceptors?.interceptors ?? []) - return config.interceptors + return config } // MARK: Private diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Configuration/AWSAPIEndpointInterceptors.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Configuration/AWSAPIEndpointInterceptors.swift index 13f6ba16e8..9be9d3c05b 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Configuration/AWSAPIEndpointInterceptors.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Configuration/AWSAPIEndpointInterceptors.swift @@ -9,6 +9,14 @@ import Amplify import Foundation import AWSPluginsCore +/// The order of interceptor decoration is as follows: +/// 1. **prelude interceptors** +/// 2. **cutomize headers** +/// 3. **customer interceptors** +/// 4. **postlude interceptors** +/// +/// **Prelude** and **postlude** interceptors are used by library maintainers to +/// integrate essential functionality for a variety of authentication types. struct AWSAPIEndpointInterceptors { // API name let apiEndpointName: APIEndpointName @@ -16,8 +24,12 @@ struct AWSAPIEndpointInterceptors { let apiAuthProviderFactory: APIAuthProviderFactory let authService: AWSAuthServiceBehavior? + var preludeInterceptors: [URLRequestInterceptor] = [] + var interceptors: [URLRequestInterceptor] = [] + var postludeInterceptors: [URLRequestInterceptor] = [] + init(endpointName: APIEndpointName, apiAuthProviderFactory: APIAuthProviderFactory, authService: AWSAuthServiceBehavior? = nil) { @@ -42,7 +54,7 @@ struct AWSAPIEndpointInterceptors { case .apiKey(let apiKeyConfig): let provider = BasicAPIKeyProvider(apiKey: apiKeyConfig.apiKey) let interceptor = APIKeyURLRequestInterceptor(apiKeyProvider: provider) - addInterceptor(interceptor) + preludeInterceptors.append(interceptor) case .awsIAM(let iamConfig): guard let authService = authService else { throw PluginError.pluginConfigurationError("AuthService is not set for IAM", @@ -52,7 +64,7 @@ struct AWSAPIEndpointInterceptors { let interceptor = IAMURLRequestInterceptor(iamCredentialsProvider: provider, region: iamConfig.region, endpointType: endpointType) - addInterceptor(interceptor) + postludeInterceptors.append(interceptor) case .amazonCognitoUserPools: guard let authService = authService else { throw PluginError.pluginConfigurationError("AuthService not set for cognito user pools", @@ -60,7 +72,7 @@ struct AWSAPIEndpointInterceptors { } let provider = BasicUserPoolTokenProvider(authService: authService) let interceptor = AuthTokenURLRequestInterceptor(authTokenProvider: provider) - addInterceptor(interceptor) + preludeInterceptors.append(interceptor) case .openIDConnect: guard let oidcAuthProvider = apiAuthProviderFactory.oidcAuthProvider() else { throw PluginError.pluginConfigurationError("AuthService not set for OIDC", @@ -68,7 +80,7 @@ struct AWSAPIEndpointInterceptors { } let wrappedAuthProvider = AuthTokenProviderWrapper(tokenAuthProvider: oidcAuthProvider) let interceptor = AuthTokenURLRequestInterceptor(authTokenProvider: wrappedAuthProvider) - addInterceptor(interceptor) + preludeInterceptors.append(interceptor) case .function: guard let functionAuthProvider = apiAuthProviderFactory.functionAuthProvider() else { throw PluginError.pluginConfigurationError("AuthService not set for function auth", @@ -76,7 +88,7 @@ struct AWSAPIEndpointInterceptors { } let wrappedAuthProvider = AuthTokenProviderWrapper(tokenAuthProvider: functionAuthProvider) let interceptor = AuthTokenURLRequestInterceptor(authTokenProvider: wrappedAuthProvider) - addInterceptor(interceptor) + preludeInterceptors.append(interceptor) } } } diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/RequestInterceptor/AuthTokenURLRequestInterceptor.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/RequestInterceptor/AuthTokenURLRequestInterceptor.swift index 0175364b7a..d7392748e8 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/RequestInterceptor/AuthTokenURLRequestInterceptor.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/RequestInterceptor/AuthTokenURLRequestInterceptor.swift @@ -32,9 +32,7 @@ struct AuthTokenURLRequestInterceptor: URLRequestInterceptor { mutableRequest.setValue(amzDate, forHTTPHeaderField: URLRequestConstants.Header.xAmzDate) - mutableRequest.setValue(URLRequestConstants.ContentType.applicationJson, - forHTTPHeaderField: URLRequestConstants.Header.contentType) - mutableRequest.setValue(userAgent, + mutableRequest.addValue(userAgent, forHTTPHeaderField: URLRequestConstants.Header.userAgent) let token: String diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/RequestInterceptor/IAMURLRequestInterceptor.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/RequestInterceptor/IAMURLRequestInterceptor.swift index 013cac0459..0a9b53fe93 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/RequestInterceptor/IAMURLRequestInterceptor.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/RequestInterceptor/IAMURLRequestInterceptor.swift @@ -36,7 +36,6 @@ struct IAMURLRequestInterceptor: URLRequestInterceptor { throw APIError.unknown("Could not get host from mutable request", "") } - request.setValue(URLRequestConstants.ContentType.applicationJson, forHTTPHeaderField: URLRequestConstants.Header.contentType) request.setValue(host, forHTTPHeaderField: "host") request.setValue(userAgent, forHTTPHeaderField: URLRequestConstants.Header.userAgent) diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLOperation.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLOperation.swift index 2f6ed4ba27..859981e321 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLOperation.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLOperation.swift @@ -34,6 +34,10 @@ final public class AWSGraphQLOperation: GraphQLOperation { } override public func main() { + Task { await mainAsync() } + } + + private func mainAsync() async { Amplify.API.log.debug("Starting \(request.operationType) \(id)") if isCancelled { @@ -41,43 +45,79 @@ final public class AWSGraphQLOperation: GraphQLOperation { return } - // Validate the request - do { - try request.validate() - } catch let error as APIError { + let urlRequest = validateRequest(request).flatMap(buildURLRequest(from:)) + let finalRequest = await getEndpointInterceptors(from: request).flatMapAsync { requestInterceptors in + let preludeInterceptors = requestInterceptors?.preludeInterceptors ?? [] + let customerInterceptors = requestInterceptors?.interceptors ?? [] + let postludeInterceptors = requestInterceptors?.postludeInterceptors ?? [] + + var finalResult = urlRequest + // apply prelude interceptors + for interceptor in preludeInterceptors { + finalResult = await finalResult.flatMapAsync { request in + await applyInterceptor(interceptor, request: request) + } + } + + // there is no customize headers for GraphQLOperationRequest + + // apply customer interceptors + for interceptor in customerInterceptors { + finalResult = await finalResult.flatMapAsync { request in + await applyInterceptor(interceptor, request: request) + } + } + + // apply postlude interceptor + for interceptor in postludeInterceptors { + finalResult = await finalResult.flatMapAsync { request in + await applyInterceptor(interceptor, request: request) + } + } + return finalResult + } + + switch finalRequest { + case .success(let finalRequest): + if isCancelled { + finish() + return + } + + // Begin network task + Amplify.API.log.debug("Starting network task for \(request.operationType) \(id)") + let task = session.dataTaskBehavior(with: finalRequest) + mapper.addPair(operation: self, task: task) + task.resume() + case .failure(let error): dispatch(result: .failure(error)) finish() - return - } catch { - dispatch(result: .failure(APIError.unknown("Could not validate request", "", nil))) - finish() - return } + } - // Retrieve endpoint configuration - let endpointConfig: AWSAPICategoryPluginConfiguration.EndpointConfig - let requestInterceptors: [URLRequestInterceptor] - + private func validateRequest(_ request: GraphQLOperationRequest) -> Result, APIError> { do { - endpointConfig = try pluginConfig.endpoints.getConfig(for: request.apiName, endpointType: .graphQL) - - if let pluginOptions = request.options.pluginOptions as? AWSPluginOptions, - let authType = pluginOptions.authType { - requestInterceptors = try pluginConfig.interceptorsForEndpoint(withConfig: endpointConfig, - authType: authType) - } else { - requestInterceptors = try pluginConfig.interceptorsForEndpoint(withConfig: endpointConfig) - } + try request.validate() + return .success(request) } catch let error as APIError { - dispatch(result: .failure(error)) - finish() - return + return .failure(error) } catch { - dispatch(result: .failure(APIError.unknown("Could not get endpoint configuration", "", nil))) - finish() - return + return .failure(APIError.unknown("Could not validate request", "", nil)) } + } + private func buildURLRequest(from request: GraphQLOperationRequest) -> Result { + getEndpointConfig(from: request).flatMap { endpointConfig in + getRequestPayload(from: request).map { requestPayload in + GraphQLOperationRequestUtils.constructRequest( + with: endpointConfig.baseURL, + requestPayload: requestPayload + ) + } + } + } + + private func getRequestPayload(from request: GraphQLOperationRequest) -> Result { // Prepare request payload let queryDocument = GraphQLOperationRequestUtils.getQueryDocument(document: request.document, variables: request.variables) @@ -87,48 +127,64 @@ final public class AWSGraphQLOperation: GraphQLOperation { let prettyPrintedQueryDocument = String(data: serializedJSON, encoding: .utf8) { Amplify.API.log.verbose("\(prettyPrintedQueryDocument)") } - let requestPayload: Data + do { - requestPayload = try JSONSerialization.data(withJSONObject: queryDocument) + return .success(try JSONSerialization.data(withJSONObject: queryDocument)) } catch { - dispatch(result: .failure(APIError.operationError("Failed to serialize query document", - "fix the document or variables", - error))) - finish() - return + return .failure(APIError.operationError( + "Failed to serialize query document", + "fix the document or variables", + error + )) } + } - // Create request - let urlRequest = GraphQLOperationRequestUtils.constructRequest(with: endpointConfig.baseURL, - requestPayload: requestPayload) - - Task { - // Intercept request - var finalRequest = urlRequest - for interceptor in requestInterceptors { - do { - finalRequest = try await interceptor.intercept(finalRequest) - } catch let error as APIError { - dispatch(result: .failure(error)) - cancel() - } catch { - dispatch(result: .failure(APIError.operationError("Failed to intercept request fully.", - "Something wrong with the interceptor", - error))) - cancel() - } - } + private func getEndpointConfig(from request: GraphQLOperationRequest) -> Result { + do { + return .success(try pluginConfig.endpoints.getConfig(for: request.apiName, endpointType: .graphQL)) + } catch let error as APIError { + return .failure(error) - if isCancelled { - finish() - return + } catch { + return .failure(APIError.unknown("Could not get endpoint configuration", "", nil)) + } + } + + private func getEndpointInterceptors(from request: GraphQLOperationRequest) -> Result { + getEndpointConfig(from: request).flatMap { endpointConfig in + do { + if let pluginOptions = request.options.pluginOptions as? AWSPluginOptions, + let authType = pluginOptions.authType + { + return .success(try pluginConfig.interceptorsForEndpoint( + withConfig: endpointConfig, + authType: authType + )) + } else { + return .success(pluginConfig.interceptorsForEndpoint(withConfig: endpointConfig)) + } + } catch let error as APIError { + return .failure(error) + } catch { + return .failure(APIError.unknown("Could not get endpoint interceptors", "", nil)) } + } + } - // Begin network task - Amplify.API.log.debug("Starting network task for \(request.operationType) \(id)") - let task = session.dataTaskBehavior(with: finalRequest) - mapper.addPair(operation: self, task: task) - task.resume() + private func applyInterceptor(_ interceptor: URLRequestInterceptor, request: URLRequest) async -> Result { + do { + return .success(try await interceptor.intercept(request)) + } catch let error as APIError { + return .failure(error) + } catch { + return .failure( + APIError.operationError( + "Failed to intercept request with \(type(of: interceptor)). Error message: \(error.localizedDescription).", + "See underlying error for more details", + error + ) + ) } } + } diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSRESTOperation.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSRESTOperation.swift index 1c8efc986d..3680b22728 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSRESTOperation.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSRESTOperation.swift @@ -40,82 +40,57 @@ final public class AWSRESTOperation: AmplifyOperation< /// The work to execute for this operation override public func main() { + Task { await mainAsync() } + } + + private func mainAsync() async { if isCancelled { finish() return } - // Validate the request - do { - try request.validate() - } catch let error as APIError { - dispatch(result: .failure(error)) - finish() - return - } catch { - dispatch(result: .failure(APIError.unknown("Could not validate request", "", nil))) - finish() - return - } + let urlRequest = validateRequest(request).flatMap(buildURLRequest(from:)) + let finalRequest = await getEndpointConfig(from: request).flatMapAsync { endpointConfig in + let interceptorConfig = pluginConfig.interceptorsForEndpoint(withConfig: endpointConfig) + let preludeInterceptors = interceptorConfig?.preludeInterceptors ?? [] + let customerInterceptors = interceptorConfig?.interceptors ?? [] + let postludeInterceptors = interceptorConfig?.postludeInterceptors ?? [] + + var finalResult = urlRequest + // apply prelude interceptors + for interceptor in preludeInterceptors { + finalResult = await finalResult.flatMapAsync { request in + await applyInterceptor(interceptor, request: request) + } + } - // Retrieve endpoint configuration - let endpointConfig: AWSAPICategoryPluginConfiguration.EndpointConfig - let requestInterceptors: [URLRequestInterceptor] - do { - endpointConfig = try pluginConfig.endpoints.getConfig(for: request.apiName, endpointType: .rest) - requestInterceptors = try pluginConfig.interceptorsForEndpoint(withConfig: endpointConfig) - } catch let error as APIError { - dispatch(result: .failure(error)) - finish() - return - } catch { - dispatch(result: .failure(APIError.unknown("Could not get endpoint configuration", "", nil))) - finish() - return - } + // apply customize headers + finalResult = finalResult.map { urlRequest in + var mutableRequest = urlRequest + for (key, value) in request.headers ?? [:] { + mutableRequest.setValue(value, forHTTPHeaderField: key) + } + return mutableRequest + } - // Construct URL with path - let url: URL - do { - url = try RESTOperationRequestUtils.constructURL( - for: endpointConfig.baseURL, - withPath: request.path, - withParams: request.queryParameters - ) - } catch let error as APIError { - dispatch(result: .failure(error)) - finish() - return - } catch { - let apiError = APIError.operationError("Failed to construct URL", "", error) - dispatch(result: .failure(apiError)) - finish() - return - } + // apply customer interceptors + for interceptor in customerInterceptors { + finalResult = await finalResult.flatMapAsync { request in + await applyInterceptor(interceptor, request: request) + } + } - // Construct URL Request with url and request body - let urlRequest = RESTOperationRequestUtils.constructURLRequest(with: url, - operationType: request.operationType, - headers: request.headers, - requestPayload: request.body) - - Task { - // Intercept request - var finalRequest = urlRequest - for interceptor in requestInterceptors { - do { - finalRequest = try await interceptor.intercept(finalRequest) - } catch let error as APIError { - dispatch(result: .failure(error)) - cancel() - } catch { - dispatch(result: .failure(APIError.operationError("Failed to intercept request fully.", - "Something wrong with the interceptor", - error))) - cancel() + // apply postlude interceptor + for interceptor in postludeInterceptors { + finalResult = await finalResult.flatMapAsync { request in + await applyInterceptor(interceptor, request: request) } } + return finalResult + } + switch finalRequest { + case .success(let finalRequest): if isCancelled { finish() return @@ -126,6 +101,70 @@ final public class AWSRESTOperation: AmplifyOperation< let task = session.dataTaskBehavior(with: finalRequest) mapper.addPair(operation: self, task: task) task.resume() + case .failure(let error): + Amplify.API.log.debug("Dispatching error \(error)") + dispatch(result: .failure(error)) + finish() + } + } + + private func validateRequest(_ request: RESTOperationRequest) -> Result { + do { + try request.validate() + return .success(request) + } catch let error as APIError { + return .failure(error) + } catch { + return .failure(APIError.unknown("Could not validate request", "", nil)) + } + } + + private func getEndpointConfig( + from request: RESTOperationRequest + ) -> Result { + do { + return .success(try pluginConfig.endpoints.getConfig(for: request.apiName, endpointType: .rest)) + } catch let error as APIError { + return .failure(error) + } catch { + return .failure(APIError.unknown("Could not get endpoint configuration", "", nil)) + } + } + + private func buildURLRequest(from request: RESTOperationRequest) -> Result { + getEndpointConfig(from: request).flatMap { endpointConfig in + do { + let url = try RESTOperationRequestUtils.constructURL( + for: endpointConfig.baseURL, + withPath: request.path, + withParams: request.queryParameters + ) + return .success(RESTOperationRequestUtils.constructURLRequest( + with: url, + operationType: request.operationType, + requestPayload: request.body + )) + } catch let error as APIError { + return .failure(error) + } catch { + return .failure(APIError.operationError("Failed to construct URL", "", error)) + } + } + } + + private func applyInterceptor(_ interceptor: URLRequestInterceptor, request: URLRequest) async -> Result { + do { + return .success(try await interceptor.intercept(request)) + } catch let error as APIError { + return .failure(error) + } catch { + return .failure( + APIError.operationError( + "Failed to intercept request with \(type(of: interceptor)). Error message: \(error.localizedDescription).", + "See underlying error for more details", + error + ) + ) } } } diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/GraphQLOperationRequestUtils.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/GraphQLOperationRequestUtils.swift index 18f695e68b..80e7bc74df 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/GraphQLOperationRequestUtils.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/GraphQLOperationRequestUtils.swift @@ -22,8 +22,8 @@ class GraphQLOperationRequestUtils { // Construct a graphQL specific HTTP POST request with the request payload static func constructRequest(with baseUrl: URL, requestPayload: Data) -> URLRequest { var baseRequest = URLRequest(url: baseUrl) - let headers = ["content-type": "application/json", "Cache-Control": "no-store"] - baseRequest.allHTTPHeaderFields = headers + baseRequest.setValue("application/json", forHTTPHeaderField: "content-type") + baseRequest.setValue("no-store", forHTTPHeaderField: "cache-control") baseRequest.httpMethod = "POST" baseRequest.httpBody = requestPayload diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/RESTOperationRequestUtils.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/RESTOperationRequestUtils.swift index 9970ea0b4b..a73f46eaae 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/RESTOperationRequestUtils.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/RESTOperationRequestUtils.swift @@ -52,19 +52,13 @@ final class RESTOperationRequestUtils { } // Construct a request specific to the `RESTOperationType` - static func constructURLRequest(with url: URL, - operationType: RESTOperationType, - headers: [String: String]?, - requestPayload: Data?) -> URLRequest { - + static func constructURLRequest( + with url: URL, + operationType: RESTOperationType, + requestPayload: Data? + ) -> URLRequest { var baseRequest = URLRequest(url: url) - var requestHeaders = ["content-type": "application/json"] - if let headers = headers { - for (key, value) in headers { - requestHeaders[key] = value - } - } - baseRequest.allHTTPHeaderFields = requestHeaders + baseRequest.setValue("application/json", forHTTPHeaderField: "content-type") baseRequest.httpMethod = operationType.rawValue baseRequest.httpBody = requestPayload return baseRequest diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/Result+Async.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/Result+Async.swift new file mode 100644 index 0000000000..fbf204c971 --- /dev/null +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/Result+Async.swift @@ -0,0 +1,20 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import Foundation + +extension Result { + func flatMapAsync(_ f: (Success) async -> Result) async -> Result { + switch self { + case .success(let value): + return await f(value) + case .failure(let error): + return .failure(error) + } + } +} diff --git a/AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginFunctionalTests/GraphQLConnectionScenario4Tests.swift b/AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginFunctionalTests/GraphQLConnectionScenario4Tests.swift index 943f1911d9..cb4c3ac7a5 100644 --- a/AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginFunctionalTests/GraphQLConnectionScenario4Tests.swift +++ b/AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginFunctionalTests/GraphQLConnectionScenario4Tests.swift @@ -213,7 +213,7 @@ class GraphQLConnectionScenario4Tests: XCTestCase { } let predicate = field("postID").eq(post.id) var results: List? - let result = try await Amplify.API.query(request: .list(Comment4.self, where: predicate, limit: 1)) + let result = try await Amplify.API.query(request: .list(Comment4.self, where: predicate, limit: 3000)) switch result { case .success(let comments): results = comments diff --git a/AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginFunctionalTests/GraphQLModelBasedTests+List.swift b/AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginFunctionalTests/GraphQLModelBasedTests+List.swift index bf38826118..c42916e6e6 100644 --- a/AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginFunctionalTests/GraphQLModelBasedTests+List.swift +++ b/AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginFunctionalTests/GraphQLModelBasedTests+List.swift @@ -39,7 +39,7 @@ extension GraphQLModelBasedTests { let post = Post.keys let predicate = post.id == uuid1 || post.id == uuid2 var results: List? - let response = try await Amplify.API.query(request: .list(Post.self, where: predicate, limit: 1)) + let response = try await Amplify.API.query(request: .list(Post.self, where: predicate, limit: 3000)) guard case .success(let graphQLresponse) = response else { XCTFail("Missing successful response") diff --git a/AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginRESTIAMTests/RESTWithIAMIntegrationTests.swift b/AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginRESTIAMTests/RESTWithIAMIntegrationTests.swift index 8f96b69425..b63c8d2fda 100644 --- a/AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginRESTIAMTests/RESTWithIAMIntegrationTests.swift +++ b/AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginRESTIAMTests/RESTWithIAMIntegrationTests.swift @@ -159,6 +159,24 @@ class RESTWithIAMIntegrationTests: XCTestCase { XCTAssertEqual(statusCode, 404) } } + + func testRestRequest_withCustomizeHeaders_succefullyOverride() async throws { + let request = RESTRequest(path: "/items", headers: ["Content-Type": "text/plain"]) + do { + _ = try await Amplify.API.get(request: request) + } catch { + guard let apiError = error as? APIError else { + XCTFail("Error should be APIError") + return + } + guard case let .httpStatusError(statusCode, _) = apiError else { + XCTFail("Error should be httpStatusError") + return + } + + XCTAssertEqual(statusCode, 403) + } + } } extension RESTWithIAMIntegrationTests: DefaultLogger { } diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AWSAPICategoryPlugin+InterceptorBehaviorTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AWSAPICategoryPlugin+InterceptorBehaviorTests.swift index adbca6ccfb..2974af9f4d 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AWSAPICategoryPlugin+InterceptorBehaviorTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AWSAPICategoryPlugin+InterceptorBehaviorTests.swift @@ -14,12 +14,16 @@ class AWSAPICategoryPluginInterceptorBehaviorTests: AWSAPICategoryPluginTestBase func testAddInterceptor() throws { XCTAssertNotNil(apiPlugin.pluginConfig.endpoints[apiName]) - XCTAssertEqual(apiPlugin.pluginConfig.interceptorsForEndpoint(named: apiName).count, 0) + XCTAssertEqual(apiPlugin.pluginConfig.interceptorsForEndpoint(named: apiName)?.preludeInterceptors.count, 0) + XCTAssertEqual(apiPlugin.pluginConfig.interceptorsForEndpoint(named: apiName)?.interceptors.count, 0) + XCTAssertEqual(apiPlugin.pluginConfig.interceptorsForEndpoint(named: apiName)?.postludeInterceptors.count, 0) let provider = BasicUserPoolTokenProvider(authService: authService) let requestInterceptor = AuthTokenURLRequestInterceptor(authTokenProvider: provider) try apiPlugin.add(interceptor: requestInterceptor, for: apiName) - XCTAssertEqual(apiPlugin.pluginConfig.interceptorsForEndpoint(named: apiName).count, 1) + XCTAssertEqual(apiPlugin.pluginConfig.interceptorsForEndpoint(named: apiName)?.preludeInterceptors.count, 0) + XCTAssertEqual(apiPlugin.pluginConfig.interceptorsForEndpoint(named: apiName)?.interceptors.count, 1) + XCTAssertEqual(apiPlugin.pluginConfig.interceptorsForEndpoint(named: apiName)?.postludeInterceptors.count, 0) } } diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Configuration/AWSAPICategoryPluginConfigurationTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Configuration/AWSAPICategoryPluginConfigurationTests.swift index 733065779e..a2f7503f01 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Configuration/AWSAPICategoryPluginConfigurationTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Configuration/AWSAPICategoryPluginConfigurationTests.swift @@ -58,18 +58,22 @@ class AWSAPICategoryPluginConfigurationTests: XCTestCase { func testAddInterceptors() { let apiKeyInterceptor = APIKeyURLRequestInterceptor(apiKeyProvider: BasicAPIKeyProvider(apiKey: apiKey)) config?.addInterceptor(apiKeyInterceptor, toEndpoint: graphQLAPI) - XCTAssertEqual(config?.interceptorsForEndpoint(named: graphQLAPI).count, 1) + XCTAssertEqual(config?.interceptorsForEndpoint(named: graphQLAPI)?.preludeInterceptors.count, 0) + XCTAssertEqual(config?.interceptorsForEndpoint(named: graphQLAPI)?.interceptors.count, 1) + XCTAssertEqual(config?.interceptorsForEndpoint(named: graphQLAPI)?.postludeInterceptors.count, 0) } /// Given: multiple interceptors conforming to URLRequestInterceptor and an EndpointConfig /// When: interceptorsForEndpoint is called with the given EndpointConfig /// Then: the registered interceptors are returned - func testInterceptorsForEndpointWithConfig() throws { + func testInterceptorsForEndpointWithConfig() { let apiKeyInterceptor = APIKeyURLRequestInterceptor(apiKeyProvider: BasicAPIKeyProvider(apiKey: apiKey)) config?.addInterceptor(apiKeyInterceptor, toEndpoint: graphQLAPI) config?.addInterceptor(CustomURLInterceptor(), toEndpoint: graphQLAPI) - let interceptors = try config?.interceptorsForEndpoint(withConfig: endpointConfig!) - XCTAssertEqual(interceptors!.count, 2) + let interceptors = config?.interceptorsForEndpoint(withConfig: endpointConfig!) + XCTAssertEqual(interceptors!.preludeInterceptors.count, 0) + XCTAssertEqual(interceptors!.interceptors.count, 2) + XCTAssertEqual(interceptors!.postludeInterceptors.count, 0) } /// Given: multiple interceptors conforming to URLRequestInterceptor @@ -84,9 +88,9 @@ class AWSAPICategoryPluginConfigurationTests: XCTestCase { let interceptors = try config?.interceptorsForEndpoint(withConfig: endpointConfig!, authType: .amazonCognitoUserPools) - XCTAssertEqual(interceptors!.count, 2) - XCTAssertNotNil(interceptors![0] as? AuthTokenURLRequestInterceptor) - XCTAssertNotNil(interceptors![1] as? CustomURLInterceptor) + XCTAssertEqual(interceptors!.preludeInterceptors.count, 1) + XCTAssertNotNil(interceptors!.preludeInterceptors[0] as? AuthTokenURLRequestInterceptor) + XCTAssertNotNil(interceptors!.interceptors[1] as? CustomURLInterceptor) } /// Given: an auth interceptor conforming to URLRequestInterceptor @@ -99,8 +103,8 @@ class AWSAPICategoryPluginConfigurationTests: XCTestCase { let interceptors = try config?.interceptorsForEndpoint(withConfig: endpointConfig!, authType: .apiKey) - XCTAssertEqual(interceptors!.count, 1) - XCTAssertNotNil(interceptors![0] as? APIKeyURLRequestInterceptor) + XCTAssertEqual(interceptors!.preludeInterceptors.count, 1) + XCTAssertNotNil(interceptors!.preludeInterceptors[0] as? APIKeyURLRequestInterceptor) } // MARK: - Helpers diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Configuration/AWSAPIEndpointInterceptorsTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Configuration/AWSAPIEndpointInterceptorsTests.swift index 5cab1d4a91..64443e9ffc 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Configuration/AWSAPIEndpointInterceptorsTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Configuration/AWSAPIEndpointInterceptorsTests.swift @@ -31,8 +31,9 @@ class AWSAPIEndpointInterceptorsTests: XCTestCase { try interceptorConfig.addAuthInterceptorsToEndpoint(endpointType: .graphQL, authConfiguration: config!) - XCTAssertEqual(interceptorConfig.interceptors.count, 1) - + XCTAssertEqual(interceptorConfig.preludeInterceptors.count, 1) + XCTAssertEqual(interceptorConfig.interceptors.count, 0) + XCTAssertEqual(interceptorConfig.postludeInterceptors.count, 0) } /// Given: an AWSAPIEndpointInterceptors @@ -44,7 +45,9 @@ class AWSAPIEndpointInterceptorsTests: XCTestCase { try interceptorConfig.addAuthInterceptorsToEndpoint(endpointType: .graphQL, authConfiguration: config!) interceptorConfig.addInterceptor(CustomInterceptor()) - XCTAssertEqual(interceptorConfig.interceptors.count, 2) + XCTAssertEqual(interceptorConfig.preludeInterceptors.count, 1) + XCTAssertEqual(interceptorConfig.interceptors.count, 1) + XCTAssertEqual(interceptorConfig.postludeInterceptors.count, 0) } func testaddMultipleAuthInterceptors() throws { @@ -69,10 +72,12 @@ class AWSAPIEndpointInterceptorsTests: XCTestCase { try interceptorConfig.addAuthInterceptorsToEndpoint(endpointType: .graphQL, authConfiguration: userPoolConfig) - XCTAssertEqual(interceptorConfig.interceptors.count, 3) - XCTAssertNotNil(interceptorConfig.interceptors[0] as? APIKeyURLRequestInterceptor) - XCTAssertNotNil(interceptorConfig.interceptors[1] as? IAMURLRequestInterceptor) - XCTAssertNotNil(interceptorConfig.interceptors[2] as? AuthTokenURLRequestInterceptor) + XCTAssertEqual(interceptorConfig.preludeInterceptors.count, 2) + XCTAssertEqual(interceptorConfig.interceptors.count, 0) + XCTAssertEqual(interceptorConfig.postludeInterceptors.count, 1) + XCTAssertNotNil(interceptorConfig.preludeInterceptors[0] as? APIKeyURLRequestInterceptor) + XCTAssertNotNil(interceptorConfig.preludeInterceptors[1] as? AuthTokenURLRequestInterceptor) + XCTAssertNotNil(interceptorConfig.postludeInterceptors[0] as? IAMURLRequestInterceptor) } // MARK: - Test Helpers diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/AuthTokenURLRequestInterceptorTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/AuthTokenURLRequestInterceptorTests.swift index ed1df2905b..25370ac15b 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/AuthTokenURLRequestInterceptorTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/AuthTokenURLRequestInterceptorTests.swift @@ -15,7 +15,11 @@ class AuthTokenURLRequestInterceptorTests: XCTestCase { func testAuthTokenInterceptor() async throws { let mockTokenProvider = MockTokenProvider() let interceptor = AuthTokenURLRequestInterceptor(authTokenProvider: mockTokenProvider) - let request = URLRequest(url: URL(string: "http://anapiendpoint.ca")!) + let request = RESTOperationRequestUtils.constructURLRequest( + with: URL(string: "http://anapiendpoint.ca")!, + operationType: .get, + requestPayload: nil + ) guard let headers = try await interceptor.intercept(request).allHTTPHeaderFields else { XCTFail("Failed retrieving headers") diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/AWSRESTOperationTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/AWSRESTOperationTests.swift index 6162d63898..c65f1257ac 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/AWSRESTOperationTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/AWSRESTOperationTests.swift @@ -96,4 +96,40 @@ class AWSRESTOperationTests: OperationTestBase { wait(for: [callbackInvoked], timeout: 1.0) } + func testRESTOperation_withCustomHeader_shouldOverrideDefaultAmplifyHeaders() throws { + let expectedHeaderValue = "text/plain" + let sentData = Data([0x00, 0x01, 0x02, 0x03]) + try setUpPluginForSingleResponse(sending: sentData, for: .rest) + + let validated = expectation(description: "Header override is validated") + try apiPlugin.add(interceptor: TestURLRequestInterceptor(validate: { request in + defer { validated.fulfill() } + return request.allHTTPHeaderFields?["Content-Type"] == expectedHeaderValue + }), for: "Valid") + + let callbackInvoked = expectation(description: "Callback was invoked") + let request = RESTRequest(apiName: "Valid", path: "/path", headers: ["Content-Type": expectedHeaderValue]) + _ = apiPlugin.get(request: request) { event in + switch event { + case .success(let data): + XCTAssertEqual(data, sentData) + case .failure(let error): + XCTFail("Unexpected failure: \(error)") + } + callbackInvoked.fulfill() + } + wait(for: [callbackInvoked, validated], timeout: 1.0) + } + +} + +fileprivate struct TestURLRequestInterceptor: URLRequestInterceptor { + let validate: (URLRequest) -> Bool + + func intercept(_ request: URLRequest) async throws -> URLRequest { + XCTAssertTrue(validate(request)) + return request + } + + } diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Support/Utils/RESTRequestUtilsTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Support/Utils/RESTRequestUtilsTests.swift index 8e1e5ec1ef..7d3af18917 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Support/Utils/RESTRequestUtilsTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Support/Utils/RESTRequestUtilsTests.swift @@ -70,7 +70,6 @@ class RESTRequestUtilsTests: XCTestCase { let urlRequest = RESTOperationRequestUtils.constructURLRequest( with: url, operationType: .get, - headers: nil, requestPayload: nil ) diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Support/Utils/Result+AsyncTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Support/Utils/Result+AsyncTests.swift new file mode 100644 index 0000000000..5ae473b899 --- /dev/null +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Support/Utils/Result+AsyncTests.swift @@ -0,0 +1,54 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import XCTest +@testable import AWSAPIPlugin + +class ResultAsyncTests: XCTestCase { + + func testFlatMapAsync_withSuccess_applyFunction() async { + func plus1(_ number: Int) async -> Int { + return number + 1 + } + + let result = Result.success(0) + let plus1Result = await result.flatMapAsync { + .success(await plus1($0)) + } + + switch plus1Result { + case .success(let plus1Result): + XCTAssertEqual(plus1Result, 1) + case .failure(let error): + XCTFail("Failed with error \(error)") + } + } + + func testFlatMapAsync_withFailure_notApplyFunction() async { + func arrayCount(_ array: [Int]) async -> Int { + return array.count + } + + + let expectedError = TestError() + let result = Result<[Int], Error>.failure(expectedError) + let count = await result.flatMapAsync { + .success(await arrayCount($0)) + } + + switch count { + case .success: + XCTFail("Should fail") + case .failure(let error): + XCTAssertTrue(error is TestError) + XCTAssertEqual(ObjectIdentifier(expectedError), ObjectIdentifier(error as! TestError)) + } + } +} + +fileprivate class TestError: Error { }