Skip to content

Commit

Permalink
fix(datastore): using URLProtocol monitor multiAuth request headers
Browse files Browse the repository at this point in the history
  • Loading branch information
Di Wu committed Sep 14, 2023
1 parent 6f95db9 commit 6b7109f
Show file tree
Hide file tree
Showing 5 changed files with 297 additions and 125 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import XCTest
import Combine
import AWSDataStorePlugin
import AWSPluginsCore
import AWSAPIPlugin
@testable import AWSAPIPlugin
import AWSCognitoAuthPlugin

#if !os(watchOS)
Expand All @@ -23,55 +23,88 @@ struct TestUser {
let password: String
}

class AuthRecorderInterceptor: URLRequestInterceptor {
let awsAuthService: AWSAuthService = AWSAuthService()
var consumedAuthTypes: Set<AWSAuthorizationType> = []
private let accessQueue = DispatchQueue(label: "com.amazon.AuthRecorderInterceptor.consumedAuthTypes")

private func recordAuthType(_ authType: AWSAuthorizationType) {
accessQueue.async {
self.consumedAuthTypes.insert(authType)
}
}
class DataStoreAuthBaseTestURLSessionFactory: URLSessionBehaviorFactory {
static let testIdHeaderKey = "x-amplify-test"

func intercept(_ request: URLRequest) throws -> URLRequest {
guard let headers = request.allHTTPHeaderFields else {
fatalError("No headers found in request \(request)")
}
static let subject = PassthroughSubject<(String, Set<AWSAuthorizationType>), Never>()

let authHeaderValue = headers["Authorization"]
let apiKeyHeaderValue = headers["x-api-key"]
class Sniffer: URLProtocol {

if apiKeyHeaderValue != nil {
recordAuthType(.apiKey)
}
override class func canInit(with request: URLRequest) -> Bool {
guard let headers = request.allHTTPHeaderFields else {
fatalError("No headers found in request \(request)")
}

guard let testId = headers[DataStoreAuthBaseTestURLSessionFactory.testIdHeaderKey] else {
return false
}

var result: Set<AWSAuthorizationType> = []
let authHeaderValue = headers["Authorization"]
let apiKeyHeaderValue = headers["x-api-key"]

if apiKeyHeaderValue != nil {
result.insert(.apiKey)
}

if let authHeaderValue = authHeaderValue,
case let .success(claims) = AWSAuthService().getTokenClaims(tokenString: authHeaderValue),
let cognitoIss = claims["iss"] as? String, cognitoIss.contains("cognito") {
result.insert(.amazonCognitoUserPools)
}

if let authHeaderValue = authHeaderValue,
authHeaderValue.starts(with: "AWS4-HMAC-SHA256") {
result.insert(.awsIAM)
}

if let authHeaderValue = authHeaderValue,
case let .success(claims) = awsAuthService.getTokenClaims(tokenString: authHeaderValue),
let cognitoIss = claims["iss"] as? String, cognitoIss.contains("cognito") {
recordAuthType(.amazonCognitoUserPools)
DataStoreAuthBaseTestURLSessionFactory.subject.send((testId, result))
return false
}

if let authHeaderValue = authHeaderValue,
authHeaderValue.starts(with: "AWS4-HMAC-SHA256") {
recordAuthType(.awsIAM)
}

class Interceptor: URLRequestInterceptor {
let testId: String?

init(testId: String?) {
self.testId = testId
}

return request
func intercept(_ request: URLRequest) async throws -> URLRequest {
if let testId {
var mutableRequest = request
mutableRequest.setValue(testId, forHTTPHeaderField: DataStoreAuthBaseTestURLSessionFactory.testIdHeaderKey)
return mutableRequest
}
return request
}
}

func reset() {
consumedAuthTypes = []
func makeSession(withDelegate delegate: URLSessionBehaviorDelegate?) -> URLSessionBehavior {
let urlSessionDelegate = delegate?.asURLSessionDelegate
let configuration = URLSessionConfiguration.default
configuration.tlsMinimumSupportedProtocolVersion = .TLSv12
configuration.tlsMaximumSupportedProtocolVersion = .TLSv13
configuration.protocolClasses?.insert(Sniffer.self, at: 0)

let session = URLSession(configuration: configuration,
delegate: urlSessionDelegate,
delegateQueue: nil)
return AmplifyURLSession(session: session)
}


}


class AWSDataStoreAuthBaseTest: XCTestCase {
var requests: Set<AnyCancellable> = []

var amplifyConfig: AmplifyConfiguration!
var user1: TestUser?
var user2: TestUser?
var authRecorderInterceptor: AuthRecorderInterceptor!

override func setUp() {
continueAfterFailure = false
Expand Down Expand Up @@ -138,8 +171,6 @@ class AWSDataStoreAuthBaseTest: XCTestCase {
self.user1 = TestUser(username: user1, password: passwordUser1)
self.user2 = TestUser(username: user2, password: passwordUser2)

authRecorderInterceptor = AuthRecorderInterceptor()

amplifyConfig = try TestConfigHelper.retrieveAmplifyConfiguration(forResource: configFile)

} catch {
Expand All @@ -161,7 +192,8 @@ class AWSDataStoreAuthBaseTest: XCTestCase {
func setup(
withModels models: AmplifyModelRegistration,
testType: DataStoreAuthTestType,
apiPluginFactory: () -> AWSAPIPlugin = { AWSAPIPlugin(sessionFactory: AmplifyURLSessionFactory()) }
testId: String? = nil,
apiPluginFactory: () -> AWSAPIPlugin = { AWSAPIPlugin(sessionFactory: DataStoreAuthBaseTestURLSessionFactory()) }
) async {
do {
setupCredentials(forAuthStrategy: testType)
Expand All @@ -182,7 +214,10 @@ class AWSDataStoreAuthBaseTest: XCTestCase {

// register auth recorder interceptor
let apiName = try apiEndpointName()
try apiPlugin.add(interceptor: authRecorderInterceptor, for: apiName)
try apiPlugin.add(
interceptor: DataStoreAuthBaseTestURLSessionFactory.Interceptor(testId: testId),
for: apiName
)

await signOut()
} catch {
Expand Down Expand Up @@ -487,13 +522,27 @@ extension AWSDataStoreAuthBaseTest {
await waitForExpectations([expectations.mutationDelete, expectations.mutationDeleteProcessed], timeout: 60)
}

func assertUsedAuthTypes(_ authTypes: [AWSAuthorizationType],
file: StaticString = #file,
line: UInt = #line) {
XCTAssertEqual(authRecorderInterceptor.consumedAuthTypes,
Set(authTypes),
file: file,
line: line)
func assertUsedAuthTypes(
testId: String,
authTypes: [AWSAuthorizationType],
file: StaticString = #file,
line: UInt = #line
) -> XCTestExpectation {
let expectation = expectation(description: "Should have expected auth types")
expectation.assertForOverFulfill = false
DataStoreAuthBaseTestURLSessionFactory.subject
.filter { $0.0 == testId }
.map { $0.1 }
.collect(.byTime(DispatchQueue.global(), .milliseconds(3500)))
.sink {
let result = $0.reduce(Set<AWSAuthorizationType>()) { partialResult, data in
partialResult.union(data)
}
XCTAssertEqual(result, Set(authTypes), file: file, line: line)
expectation.fulfill()
}
.store(in: &requests)
return expectation
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class AWSDataStoreMultiAuthCombinationTests: AWSDataStoreAuthBaseTest {
/// Then: DataStore is successfully initialized.
func testDataStoreReadyState() async {
await setup(withModels: PrivatePublicComboModels(),
testType: .multiAuth)
testType: .multiAuth)
await signIn(user: user1)

let expectations = makeExpectations()
Expand Down Expand Up @@ -58,14 +58,17 @@ class AWSDataStoreMultiAuthCombinationTests: AWSDataStoreAuthBaseTest {
/// - DataStore is successfully initialized, sync/mutation/subscription network requests f
/// or PrivatePublicComboUPPost are sent with IAM auth for authenticated users.
func testOperationsForPrivatePublicComboUPPost() async {
let testId = UUID().uuidString
await setup(withModels: PrivatePublicComboModels(),
testType: .multiAuth)
testType: .multiAuth,
testId: testId)
await signIn(user: user1)

let expectations = makeExpectations()

await assertDataStoreReady(expectations)

let authTypeExpecation = assertUsedAuthTypes(testId: testId, authTypes: [.amazonCognitoUserPools])
// Query
await assertQuerySuccess(modelType: PrivatePublicComboUPPost.self,
expectations, onFailure: { error in
Expand All @@ -78,7 +81,7 @@ class AWSDataStoreMultiAuthCombinationTests: AWSDataStoreAuthBaseTest {
XCTFail("Error mutation \(error)")
}

assertUsedAuthTypes([.apiKey, .amazonCognitoUserPools])
await fulfillment(of: [authTypeExpecation], timeout: 5)
}

/// Given: a user signed in with API key
Expand All @@ -87,8 +90,12 @@ class AWSDataStoreMultiAuthCombinationTests: AWSDataStoreAuthBaseTest {
/// - DataStore is successfully initialized, sync/mutation/subscription network requests
/// for PrivatePublicComboAPIPost are sent with API key auth for authenticated users.
func testOperationsForPrivatePublicComboAPIPostAuthenticatedUser() async {
let testId = UUID().uuidString
let authTypeExpecation = assertUsedAuthTypes(testId: testId, authTypes: [.apiKey, .amazonCognitoUserPools])

await setup(withModels: PrivatePublicComboModels(),
testType: .multiAuth)
testType: .multiAuth,
testId: testId)
await signIn(user: user1)

let expectations = makeExpectations()
Expand All @@ -101,12 +108,14 @@ class AWSDataStoreMultiAuthCombinationTests: AWSDataStoreAuthBaseTest {
XCTFail("Error query \(error)")
})


// Mutation
await assertMutations(model: PrivatePublicComboAPIPost(name: "name"),
expectations) { error in
XCTFail("Error mutation \(error)")
}
assertUsedAuthTypes([.amazonCognitoUserPools, .apiKey])

await fulfillment(of: [authTypeExpecation], timeout: 5)
}

/// Given: an unauthenticated user
Expand All @@ -118,14 +127,17 @@ class AWSDataStoreMultiAuthCombinationTests: AWSDataStoreAuthBaseTest {
/// PrivatePublicComboUPPost does not sync for unauthenticated users, but it does not block the other models
/// from syncing and DataStore getting to a “ready” state.
func testOperationsForPrivatePublicComboAPIPost() async {
let testId = UUID().uuidString
await setup(withModels: PrivatePublicComboModels(),
testType: .multiAuth)
testType: .multiAuth,
testId: testId)

let expectations = makeExpectations()

// PrivatePublicComboUPPost won't sync for unauthenticated users
await assertDataStoreReady(expectations, expectedModelSynced: 1)

let authTypeExpecation = assertUsedAuthTypes(testId: testId, authTypes: [.apiKey])
// Query
await assertQuerySuccess(modelType: PrivatePublicComboAPIPost.self,
expectations, onFailure: { error in
Expand All @@ -137,7 +149,8 @@ class AWSDataStoreMultiAuthCombinationTests: AWSDataStoreAuthBaseTest {
expectations) { error in
XCTFail("Error mutation \(error)")
}
assertUsedAuthTypes([.apiKey])

await fulfillment(of: [authTypeExpecation], timeout: 5)
}
}

Expand Down
Loading

0 comments on commit 6b7109f

Please sign in to comment.