Skip to content

Commit

Permalink
change interceptors applying logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Di Wu committed Sep 6, 2023
1 parent 85d0fc8 commit 7c2b7ab
Show file tree
Hide file tree
Showing 16 changed files with 265 additions and 136 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,22 +84,19 @@ public struct AWSAPICategoryPluginConfiguration {
interceptors[apiName]?.addInterceptor(interceptor)
}

/// Returns all the interceptors registered for `apiName` API endpoint
/// Returns all the customer defined interceptors registered for `apiName` API endpoint
/// - Parameter apiName: API endpoint name
/// - Returns: request interceptors
internal func interceptorsForEndpoint(named apiName: APIEndpointName) -> [URLRequestInterceptor] {
guard let interceptorsConfig = interceptors[apiName] else {
return []
}
return interceptorsConfig.interceptors
internal func interceptorsForEndpoint(named apiName: APIEndpointName) -> AWSAPIEndpointInterceptors? {
return interceptors[apiName]
}

/// Returns interceptors for the provided endpointConfig
/// Returns customer defined interceptors for the provided endpointConfig
/// - Parameters:
/// - endpointConfig: endpoint configuration
/// - Throws: PluginConfigurationError in case of failure building an instance of AWSAuthorizationConfiguration
/// - Returns: An array of URLRequestInterceptor
internal func interceptorsForEndpoint(withConfig endpointConfig: EndpointConfig) throws -> [URLRequestInterceptor] {
internal func interceptorsForEndpoint(withConfig endpointConfig: EndpointConfig) -> AWSAPIEndpointInterceptors? {
return interceptorsForEndpoint(named: endpointConfig.name)
}

Expand All @@ -109,8 +106,10 @@ public struct AWSAPICategoryPluginConfiguration {
/// - authType: overrides the registered auth interceptor
/// - Throws: PluginConfigurationError in case of failure building an instance of AWSAuthorizationConfiguration
/// - Returns: An array of URLRequestInterceptor
internal func interceptorsForEndpoint(withConfig endpointConfig: EndpointConfig,
authType: AWSAuthorizationType) throws -> [URLRequestInterceptor] {
internal func interceptorsForEndpoint(
withConfig endpointConfig: EndpointConfig,
authType: AWSAuthorizationType
) throws -> AWSAPIEndpointInterceptors? {

guard let apiAuthProviderFactory = self.apiAuthProviderFactory else {
return interceptorsForEndpoint(named: endpointConfig.name)
Expand All @@ -126,12 +125,10 @@ public struct AWSAPICategoryPluginConfiguration {
authConfiguration: authConfiguration)

// retrieve current interceptors and replace auth interceptor
let currentInterceptors = interceptorsForEndpoint(named: endpointConfig.name).filter {
!isAuthInterceptor($0)
}
config.interceptors.append(contentsOf: currentInterceptors)
let currentInterceptors = interceptorsForEndpoint(named: endpointConfig.name)
config.interceptors.append(contentsOf: currentInterceptors?.interceptors ?? [])

return config.interceptors
return config
}

// MARK: Private
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ struct AWSAPIEndpointInterceptors {
let apiAuthProviderFactory: APIAuthProviderFactory
let authService: AWSAuthServiceBehavior?

var amplifyInterceptors: [URLRequestInterceptor] = []

var checksumInterceptors: [URLRequestInterceptor] = []

var interceptors: [URLRequestInterceptor] = []

init(endpointName: APIEndpointName,
Expand All @@ -42,7 +46,7 @@ struct AWSAPIEndpointInterceptors {
case .apiKey(let apiKeyConfig):
let provider = BasicAPIKeyProvider(apiKey: apiKeyConfig.apiKey)
let interceptor = APIKeyURLRequestInterceptor(apiKeyProvider: provider)
addInterceptor(interceptor)
amplifyInterceptors.append(interceptor)
case .awsIAM(let iamConfig):
guard let authService = authService else {
throw PluginError.pluginConfigurationError("AuthService is not set for IAM",
Expand All @@ -52,31 +56,31 @@ struct AWSAPIEndpointInterceptors {
let interceptor = IAMURLRequestInterceptor(iamCredentialsProvider: provider,
region: iamConfig.region,
endpointType: endpointType)
addInterceptor(interceptor)
checksumInterceptors.append(interceptor)
case .amazonCognitoUserPools:
guard let authService = authService else {
throw PluginError.pluginConfigurationError("AuthService not set for cognito user pools",
"")
}
let provider = BasicUserPoolTokenProvider(authService: authService)
let interceptor = AuthTokenURLRequestInterceptor(authTokenProvider: provider)
addInterceptor(interceptor)
amplifyInterceptors.append(interceptor)
case .openIDConnect:
guard let oidcAuthProvider = apiAuthProviderFactory.oidcAuthProvider() else {
throw PluginError.pluginConfigurationError("AuthService not set for OIDC",
"Provide an AmplifyOIDCAuthProvider via API plugin configuration")
}
let wrappedAuthProvider = AuthTokenProviderWrapper(tokenAuthProvider: oidcAuthProvider)
let interceptor = AuthTokenURLRequestInterceptor(authTokenProvider: wrappedAuthProvider)
addInterceptor(interceptor)
amplifyInterceptors.append(interceptor)
case .function:
guard let functionAuthProvider = apiAuthProviderFactory.functionAuthProvider() else {
throw PluginError.pluginConfigurationError("AuthService not set for function auth",
"Provide an AmplifyFunctionAuthProvider via API plugin configuration")
}
let wrappedAuthProvider = AuthTokenProviderWrapper(tokenAuthProvider: functionAuthProvider)
let interceptor = AuthTokenURLRequestInterceptor(authTokenProvider: wrappedAuthProvider)
addInterceptor(interceptor)
amplifyInterceptors.append(interceptor)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ struct AuthTokenURLRequestInterceptor: URLRequestInterceptor {

mutableRequest.setValue(amzDate,
forHTTPHeaderField: URLRequestConstants.Header.xAmzDate)
mutableRequest.setValue(URLRequestConstants.ContentType.applicationJson,
forHTTPHeaderField: URLRequestConstants.Header.contentType)
mutableRequest.setValue(userAgent,
mutableRequest.addValue(userAgent,
forHTTPHeaderField: URLRequestConstants.Header.userAgent)

let token: String
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ struct IAMURLRequestInterceptor: URLRequestInterceptor {
throw APIError.unknown("Could not get host from mutable request", "")
}

request.setValue(URLRequestConstants.ContentType.applicationJson, forHTTPHeaderField: URLRequestConstants.Header.contentType)
request.setValue(host, forHTTPHeaderField: "host")
request.setValue(userAgent, forHTTPHeaderField: URLRequestConstants.Header.userAgent)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ final public class AWSGraphQLOperation<R: Decodable>: GraphQLOperation<R> {

// Retrieve endpoint configuration
let endpointConfig: AWSAPICategoryPluginConfiguration.EndpointConfig
let requestInterceptors: [URLRequestInterceptor]
let requestInterceptors: AWSAPIEndpointInterceptors?

do {
endpointConfig = try pluginConfig.endpoints.getConfig(for: request.apiName, endpointType: .graphQL)
Expand All @@ -66,7 +66,7 @@ final public class AWSGraphQLOperation<R: Decodable>: GraphQLOperation<R> {
requestInterceptors = try pluginConfig.interceptorsForEndpoint(withConfig: endpointConfig,
authType: authType)
} else {
requestInterceptors = try pluginConfig.interceptorsForEndpoint(withConfig: endpointConfig)
requestInterceptors = pluginConfig.interceptorsForEndpoint(withConfig: endpointConfig)
}
} catch let error as APIError {
dispatch(result: .failure(error))
Expand Down Expand Up @@ -101,34 +101,67 @@ final public class AWSGraphQLOperation<R: Decodable>: GraphQLOperation<R> {
// Create request
let urlRequest = GraphQLOperationRequestUtils.constructRequest(with: endpointConfig.baseURL,
requestPayload: requestPayload)

let amplifyInterceptors = requestInterceptors?.amplifyInterceptors ?? []
let customerInterceptors = requestInterceptors?.interceptors ?? []
let checksumInterceptors = requestInterceptors?.checksumInterceptors ?? []
Task {
// Intercept request
var finalRequest = urlRequest
for interceptor in requestInterceptors {
do {
finalRequest = try await interceptor.intercept(finalRequest)
} catch let error as APIError {
dispatch(result: .failure(error))
cancel()
} catch {
dispatch(result: .failure(APIError.operationError("Failed to intercept request fully.",
"Something wrong with the interceptor",
error)))
cancel()
var finalResult: Result<URLRequest, APIError> = .success(urlRequest)
// apply amplify interceptors
for interceptor in amplifyInterceptors {
finalResult = await finalResult.flatMapAsync { request in
await applyInterceptor(interceptor, request: request)
}
}

// there is no customer headers for GraphQLOperationRequest

// apply customer interceptors
for interceptor in customerInterceptors {
finalResult = await finalResult.flatMapAsync { request in
await applyInterceptor(interceptor, request: request)
}
}

// apply checksum interceptor
for interceptor in checksumInterceptors {
finalResult = await finalResult.flatMapAsync { request in
await applyInterceptor(interceptor, request: request)
}
}

if isCancelled {
finish()
return
switch finalResult {
case .success(let finalRequest):
if isCancelled {
finish()
return
}

// Begin network task
Amplify.API.log.debug("Starting network task for \(request.operationType) \(id)")
let task = session.dataTaskBehavior(with: finalRequest)
mapper.addPair(operation: self, task: task)
task.resume()
case .failure(let error):
dispatch(result: .failure(error))
cancel()
}
}
}

// Begin network task
Amplify.API.log.debug("Starting network task for \(request.operationType) \(id)")
let task = session.dataTaskBehavior(with: finalRequest)
mapper.addPair(operation: self, task: task)
task.resume()
private func applyInterceptor(_ interceptor: URLRequestInterceptor, request: URLRequest) async -> Result<URLRequest, APIError> {
do {
return .success(try await interceptor.intercept(request))
} catch let error as APIError {
return .failure(error)
} catch {
return .failure(
APIError.operationError(
"Failed to intercept request fully.",
"Something wrong with the interceptor",
error
)
)
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,15 @@ final public class AWSRESTOperation: AmplifyOperation<

// Retrieve endpoint configuration
let endpointConfig: AWSAPICategoryPluginConfiguration.EndpointConfig
let requestInterceptors: [URLRequestInterceptor]
let amplifyInterceptors: [URLRequestInterceptor]
let customerInterceptors: [URLRequestInterceptor]
let checksumInterceptors: [URLRequestInterceptor]
do {
endpointConfig = try pluginConfig.endpoints.getConfig(for: request.apiName, endpointType: .rest)
requestInterceptors = try pluginConfig.interceptorsForEndpoint(withConfig: endpointConfig)
let interceptorConfig = pluginConfig.interceptorsForEndpoint(withConfig: endpointConfig)
amplifyInterceptors = interceptorConfig?.amplifyInterceptors ?? []
customerInterceptors = interceptorConfig?.interceptors ?? []
checksumInterceptors = interceptorConfig?.checksumInterceptors ?? []
} catch let error as APIError {
dispatch(result: .failure(error))
finish()
Expand Down Expand Up @@ -94,41 +99,76 @@ final public class AWSRESTOperation: AmplifyOperation<
}

// Construct URL Request with url and request body
let urlRequest = RESTOperationRequestUtils.constructURLRequest(with: url,
operationType: request.operationType,
headers: request.headers,
requestPayload: request.body)
let urlRequest = RESTOperationRequestUtils.constructURLRequest(
with: url,
operationType: request.operationType,
requestPayload: request.body
)

Task {
// Intercept request
var finalRequest = urlRequest
for interceptor in requestInterceptors {
do {
finalRequest = try await interceptor.intercept(finalRequest)
} catch let error as APIError {
dispatch(result: .failure(error))
cancel()
} catch {
dispatch(result: .failure(APIError.operationError("Failed to intercept request fully.",
"Something wrong with the interceptor",
error)))
cancel()
var finalResult: Result<URLRequest, APIError> = .success(urlRequest)
// apply amplify interceptors
for interceptor in amplifyInterceptors {
finalResult = await finalResult.flatMapAsync { request in
await applyInterceptor(interceptor, request: request)
}
}

// The headers from the request object should override any of the same header modifications done by the intercepters above
finalRequest = RESTOperationRequestUtils.applyCustomizeRequestHeaders(request.headers, on: finalRequest)
// apply customer headers
finalResult = finalResult.map { urlRequest in
var mutableRequest = urlRequest
for (key, value) in request.headers ?? [:] {
mutableRequest.setValue(value, forHTTPHeaderField: key)
}
return mutableRequest
}

// apply customer interceptors
for interceptor in customerInterceptors {
finalResult = await finalResult.flatMapAsync { request in
await applyInterceptor(interceptor, request: request)
}
}

// apply checksum interceptor
for interceptor in checksumInterceptors {
finalResult = await finalResult.flatMapAsync { request in
await applyInterceptor(interceptor, request: request)
}
}

if isCancelled {
finish()
return
switch finalResult {
case .success(let finalRequest):
if isCancelled {
finish()
return
}

// Begin network task
Amplify.API.log.debug("Starting network task for \(request.operationType) \(id)")
let task = session.dataTaskBehavior(with: finalRequest)
mapper.addPair(operation: self, task: task)
task.resume()
case .failure(let error):
dispatch(result: .failure(error))
cancel()
}
}
}

// Begin network task
Amplify.API.log.debug("Starting network task for \(request.operationType) \(id)")
let task = session.dataTaskBehavior(with: finalRequest)
mapper.addPair(operation: self, task: task)
task.resume()
private func applyInterceptor(_ interceptor: URLRequestInterceptor, request: URLRequest) async -> Result<URLRequest, APIError> {
do {
return .success(try await interceptor.intercept(request))
} catch let error as APIError {
return .failure(error)
} catch {
return .failure(
APIError.operationError(
"Failed to intercept request fully.",
"Something wrong with the interceptor",
error
)
)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class GraphQLOperationRequestUtils {
// Construct a graphQL specific HTTP POST request with the request payload
static func constructRequest(with baseUrl: URL, requestPayload: Data) -> URLRequest {
var baseRequest = URLRequest(url: baseUrl)
let headers = ["content-type": "application/json", "Cache-Control": "no-store"]
baseRequest.allHTTPHeaderFields = headers
baseRequest.setValue("application/json", forHTTPHeaderField: "content-type")
baseRequest.setValue("no-store", forHTTPHeaderField: "cache-control")
baseRequest.httpMethod = "POST"
baseRequest.httpBody = requestPayload

Expand Down
Loading

0 comments on commit 7c2b7ab

Please sign in to comment.