Skip to content

Commit

Permalink
fix(api): change request interceptors applying logic (#3190)
Browse files Browse the repository at this point in the history
* fix(api): apply cutomize request headers after interceptors

* add integration test case

* change interceptors applying logic

* refactor code style

* refactor code style for rest operations

* add unit test case for customer header override

* update interceptor names

* update interceptor doc comment
  • Loading branch information
Di Wu authored Sep 12, 2023
1 parent 22ea491 commit 6f95db9
Show file tree
Hide file tree
Showing 19 changed files with 430 additions and 192 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public struct AWSAPICategoryPluginConfiguration {
self.authService = authService
}

/// Registers an interceptor for the provided API endpoint
/// Registers an customer interceptor for the provided API endpoint
/// - Parameter interceptor: operation interceptor used to decorate API requests
/// - Parameter toEndpoint: API endpoint name
mutating func addInterceptor(_ interceptor: URLRequestInterceptor,
Expand All @@ -86,20 +86,16 @@ public struct AWSAPICategoryPluginConfiguration {

/// Returns all the interceptors registered for `apiName` API endpoint
/// - Parameter apiName: API endpoint name
/// - Returns: request interceptors
internal func interceptorsForEndpoint(named apiName: APIEndpointName) -> [URLRequestInterceptor] {
guard let interceptorsConfig = interceptors[apiName] else {
return []
}
return interceptorsConfig.interceptors
/// - Returns: Optional AWSAPIEndpointInterceptors for the apiName
internal func interceptorsForEndpoint(named apiName: APIEndpointName) -> AWSAPIEndpointInterceptors? {
return interceptors[apiName]
}

/// Returns interceptors for the provided endpointConfig
/// Returns the interceptors for the provided endpointConfig
/// - Parameters:
/// - endpointConfig: endpoint configuration
/// - Throws: PluginConfigurationError in case of failure building an instance of AWSAuthorizationConfiguration
/// - Returns: An array of URLRequestInterceptor
internal func interceptorsForEndpoint(withConfig endpointConfig: EndpointConfig) throws -> [URLRequestInterceptor] {
/// - Returns: Optional AWSAPIEndpointInterceptors for the endpointConfig
internal func interceptorsForEndpoint(withConfig endpointConfig: EndpointConfig) -> AWSAPIEndpointInterceptors? {
return interceptorsForEndpoint(named: endpointConfig.name)
}

Expand All @@ -108,9 +104,11 @@ public struct AWSAPICategoryPluginConfiguration {
/// - endpointConfig: endpoint configuration
/// - authType: overrides the registered auth interceptor
/// - Throws: PluginConfigurationError in case of failure building an instance of AWSAuthorizationConfiguration
/// - Returns: An array of URLRequestInterceptor
internal func interceptorsForEndpoint(withConfig endpointConfig: EndpointConfig,
authType: AWSAuthorizationType) throws -> [URLRequestInterceptor] {
/// - Returns: Optional AWSAPIEndpointInterceptors for the endpointConfig and authType
internal func interceptorsForEndpoint(
withConfig endpointConfig: EndpointConfig,
authType: AWSAuthorizationType
) throws -> AWSAPIEndpointInterceptors? {

guard let apiAuthProviderFactory = self.apiAuthProviderFactory else {
return interceptorsForEndpoint(named: endpointConfig.name)
Expand All @@ -126,12 +124,10 @@ public struct AWSAPICategoryPluginConfiguration {
authConfiguration: authConfiguration)

// retrieve current interceptors and replace auth interceptor
let currentInterceptors = interceptorsForEndpoint(named: endpointConfig.name).filter {
!isAuthInterceptor($0)
}
config.interceptors.append(contentsOf: currentInterceptors)
let currentInterceptors = interceptorsForEndpoint(named: endpointConfig.name)
config.interceptors.append(contentsOf: currentInterceptors?.interceptors ?? [])

return config.interceptors
return config
}

// MARK: Private
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,27 @@ import Amplify
import Foundation
import AWSPluginsCore

/// The order of interceptor decoration is as follows:
/// 1. **prelude interceptors**
/// 2. **cutomize headers**
/// 3. **customer interceptors**
/// 4. **postlude interceptors**
///
/// **Prelude** and **postlude** interceptors are used by library maintainers to
/// integrate essential functionality for a variety of authentication types.
struct AWSAPIEndpointInterceptors {
// API name
let apiEndpointName: APIEndpointName

let apiAuthProviderFactory: APIAuthProviderFactory
let authService: AWSAuthServiceBehavior?

var preludeInterceptors: [URLRequestInterceptor] = []

var interceptors: [URLRequestInterceptor] = []

var postludeInterceptors: [URLRequestInterceptor] = []

init(endpointName: APIEndpointName,
apiAuthProviderFactory: APIAuthProviderFactory,
authService: AWSAuthServiceBehavior? = nil) {
Expand All @@ -42,7 +54,7 @@ struct AWSAPIEndpointInterceptors {
case .apiKey(let apiKeyConfig):
let provider = BasicAPIKeyProvider(apiKey: apiKeyConfig.apiKey)
let interceptor = APIKeyURLRequestInterceptor(apiKeyProvider: provider)
addInterceptor(interceptor)
preludeInterceptors.append(interceptor)
case .awsIAM(let iamConfig):
guard let authService = authService else {
throw PluginError.pluginConfigurationError("AuthService is not set for IAM",
Expand All @@ -52,31 +64,31 @@ struct AWSAPIEndpointInterceptors {
let interceptor = IAMURLRequestInterceptor(iamCredentialsProvider: provider,
region: iamConfig.region,
endpointType: endpointType)
addInterceptor(interceptor)
postludeInterceptors.append(interceptor)
case .amazonCognitoUserPools:
guard let authService = authService else {
throw PluginError.pluginConfigurationError("AuthService not set for cognito user pools",
"")
}
let provider = BasicUserPoolTokenProvider(authService: authService)
let interceptor = AuthTokenURLRequestInterceptor(authTokenProvider: provider)
addInterceptor(interceptor)
preludeInterceptors.append(interceptor)
case .openIDConnect:
guard let oidcAuthProvider = apiAuthProviderFactory.oidcAuthProvider() else {
throw PluginError.pluginConfigurationError("AuthService not set for OIDC",
"Provide an AmplifyOIDCAuthProvider via API plugin configuration")
}
let wrappedAuthProvider = AuthTokenProviderWrapper(tokenAuthProvider: oidcAuthProvider)
let interceptor = AuthTokenURLRequestInterceptor(authTokenProvider: wrappedAuthProvider)
addInterceptor(interceptor)
preludeInterceptors.append(interceptor)
case .function:
guard let functionAuthProvider = apiAuthProviderFactory.functionAuthProvider() else {
throw PluginError.pluginConfigurationError("AuthService not set for function auth",
"Provide an AmplifyFunctionAuthProvider via API plugin configuration")
}
let wrappedAuthProvider = AuthTokenProviderWrapper(tokenAuthProvider: functionAuthProvider)
let interceptor = AuthTokenURLRequestInterceptor(authTokenProvider: wrappedAuthProvider)
addInterceptor(interceptor)
preludeInterceptors.append(interceptor)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ struct AuthTokenURLRequestInterceptor: URLRequestInterceptor {

mutableRequest.setValue(amzDate,
forHTTPHeaderField: URLRequestConstants.Header.xAmzDate)
mutableRequest.setValue(URLRequestConstants.ContentType.applicationJson,
forHTTPHeaderField: URLRequestConstants.Header.contentType)
mutableRequest.setValue(userAgent,
mutableRequest.addValue(userAgent,
forHTTPHeaderField: URLRequestConstants.Header.userAgent)

let token: String
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ struct IAMURLRequestInterceptor: URLRequestInterceptor {
throw APIError.unknown("Could not get host from mutable request", "")
}

request.setValue(URLRequestConstants.ContentType.applicationJson, forHTTPHeaderField: URLRequestConstants.Header.contentType)
request.setValue(host, forHTTPHeaderField: "host")
request.setValue(userAgent, forHTTPHeaderField: URLRequestConstants.Header.userAgent)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,50 +34,90 @@ final public class AWSGraphQLOperation<R: Decodable>: GraphQLOperation<R> {
}

override public func main() {
Task { await mainAsync() }
}

private func mainAsync() async {
Amplify.API.log.debug("Starting \(request.operationType) \(id)")

if isCancelled {
finish()
return
}

// Validate the request
do {
try request.validate()
} catch let error as APIError {
let urlRequest = validateRequest(request).flatMap(buildURLRequest(from:))
let finalRequest = await getEndpointInterceptors(from: request).flatMapAsync { requestInterceptors in
let preludeInterceptors = requestInterceptors?.preludeInterceptors ?? []
let customerInterceptors = requestInterceptors?.interceptors ?? []
let postludeInterceptors = requestInterceptors?.postludeInterceptors ?? []

var finalResult = urlRequest
// apply prelude interceptors
for interceptor in preludeInterceptors {
finalResult = await finalResult.flatMapAsync { request in
await applyInterceptor(interceptor, request: request)
}
}

// there is no customize headers for GraphQLOperationRequest

// apply customer interceptors
for interceptor in customerInterceptors {
finalResult = await finalResult.flatMapAsync { request in
await applyInterceptor(interceptor, request: request)
}
}

// apply postlude interceptor
for interceptor in postludeInterceptors {
finalResult = await finalResult.flatMapAsync { request in
await applyInterceptor(interceptor, request: request)
}
}
return finalResult
}

switch finalRequest {
case .success(let finalRequest):
if isCancelled {
finish()
return
}

// Begin network task
Amplify.API.log.debug("Starting network task for \(request.operationType) \(id)")
let task = session.dataTaskBehavior(with: finalRequest)
mapper.addPair(operation: self, task: task)
task.resume()
case .failure(let error):
dispatch(result: .failure(error))
finish()
return
} catch {
dispatch(result: .failure(APIError.unknown("Could not validate request", "", nil)))
finish()
return
}
}

// Retrieve endpoint configuration
let endpointConfig: AWSAPICategoryPluginConfiguration.EndpointConfig
let requestInterceptors: [URLRequestInterceptor]

private func validateRequest(_ request: GraphQLOperationRequest<R>) -> Result<GraphQLOperationRequest<R>, APIError> {
do {
endpointConfig = try pluginConfig.endpoints.getConfig(for: request.apiName, endpointType: .graphQL)

if let pluginOptions = request.options.pluginOptions as? AWSPluginOptions,
let authType = pluginOptions.authType {
requestInterceptors = try pluginConfig.interceptorsForEndpoint(withConfig: endpointConfig,
authType: authType)
} else {
requestInterceptors = try pluginConfig.interceptorsForEndpoint(withConfig: endpointConfig)
}
try request.validate()
return .success(request)
} catch let error as APIError {
dispatch(result: .failure(error))
finish()
return
return .failure(error)
} catch {
dispatch(result: .failure(APIError.unknown("Could not get endpoint configuration", "", nil)))
finish()
return
return .failure(APIError.unknown("Could not validate request", "", nil))
}
}

private func buildURLRequest(from request: GraphQLOperationRequest<R>) -> Result<URLRequest, APIError> {
getEndpointConfig(from: request).flatMap { endpointConfig in
getRequestPayload(from: request).map { requestPayload in
GraphQLOperationRequestUtils.constructRequest(
with: endpointConfig.baseURL,
requestPayload: requestPayload
)
}
}
}

private func getRequestPayload(from request: GraphQLOperationRequest<R>) -> Result<Data, APIError> {
// Prepare request payload
let queryDocument = GraphQLOperationRequestUtils.getQueryDocument(document: request.document,
variables: request.variables)
Expand All @@ -87,48 +127,64 @@ final public class AWSGraphQLOperation<R: Decodable>: GraphQLOperation<R> {
let prettyPrintedQueryDocument = String(data: serializedJSON, encoding: .utf8) {
Amplify.API.log.verbose("\(prettyPrintedQueryDocument)")
}
let requestPayload: Data

do {
requestPayload = try JSONSerialization.data(withJSONObject: queryDocument)
return .success(try JSONSerialization.data(withJSONObject: queryDocument))
} catch {
dispatch(result: .failure(APIError.operationError("Failed to serialize query document",
"fix the document or variables",
error)))
finish()
return
return .failure(APIError.operationError(
"Failed to serialize query document",
"fix the document or variables",
error
))
}
}

// Create request
let urlRequest = GraphQLOperationRequestUtils.constructRequest(with: endpointConfig.baseURL,
requestPayload: requestPayload)

Task {
// Intercept request
var finalRequest = urlRequest
for interceptor in requestInterceptors {
do {
finalRequest = try await interceptor.intercept(finalRequest)
} catch let error as APIError {
dispatch(result: .failure(error))
cancel()
} catch {
dispatch(result: .failure(APIError.operationError("Failed to intercept request fully.",
"Something wrong with the interceptor",
error)))
cancel()
}
}
private func getEndpointConfig(from request: GraphQLOperationRequest<R>) -> Result<AWSAPICategoryPluginConfiguration.EndpointConfig, APIError> {
do {
return .success(try pluginConfig.endpoints.getConfig(for: request.apiName, endpointType: .graphQL))
} catch let error as APIError {
return .failure(error)

if isCancelled {
finish()
return
} catch {
return .failure(APIError.unknown("Could not get endpoint configuration", "", nil))
}
}

private func getEndpointInterceptors(from request: GraphQLOperationRequest<R>) -> Result<AWSAPIEndpointInterceptors?, APIError> {
getEndpointConfig(from: request).flatMap { endpointConfig in
do {
if let pluginOptions = request.options.pluginOptions as? AWSPluginOptions,
let authType = pluginOptions.authType
{
return .success(try pluginConfig.interceptorsForEndpoint(
withConfig: endpointConfig,
authType: authType
))
} else {
return .success(pluginConfig.interceptorsForEndpoint(withConfig: endpointConfig))
}
} catch let error as APIError {
return .failure(error)
} catch {
return .failure(APIError.unknown("Could not get endpoint interceptors", "", nil))
}
}
}

// Begin network task
Amplify.API.log.debug("Starting network task for \(request.operationType) \(id)")
let task = session.dataTaskBehavior(with: finalRequest)
mapper.addPair(operation: self, task: task)
task.resume()
private func applyInterceptor(_ interceptor: URLRequestInterceptor, request: URLRequest) async -> Result<URLRequest, APIError> {
do {
return .success(try await interceptor.intercept(request))
} catch let error as APIError {
return .failure(error)
} catch {
return .failure(
APIError.operationError(
"Failed to intercept request with \(type(of: interceptor)). Error message: \(error.localizedDescription).",
"See underlying error for more details",
error
)
)
}
}

}
Loading

0 comments on commit 6f95db9

Please sign in to comment.