Skip to content

Commit

Permalink
fix(datastore): multi auth rule for read subscription (#3316)
Browse files Browse the repository at this point in the history
* fix(datastore): multi auth rule for read subscription

* Address review comments
  • Loading branch information
thisisabhash authored Nov 3, 2023
1 parent 2604241 commit d4b957d
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,7 @@ public struct AuthRule {
self.operations = operations
}
}

extension AuthRule: Hashable {

}
33 changes: 28 additions & 5 deletions AmplifyPlugins/Core/AWSPluginsCore/Auth/AWSAuthModeStrategy.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ public protocol AuthModeStrategy: AnyObject {
init()

func authTypesFor(schema: ModelSchema, operation: ModelOperation) async -> AWSAuthorizationTypeIterator

func authTypesFor(schema: ModelSchema, operations: [ModelOperation]) async -> AWSAuthorizationTypeIterator
}

/// AuthorizationType iterator with an extra `count` property used
Expand Down Expand Up @@ -93,6 +95,11 @@ public class AWSDefaultAuthModeStrategy: AuthModeStrategy {
operation: ModelOperation) -> AWSAuthorizationTypeIterator {
return AWSAuthorizationTypeIterator(withValues: [])
}

public func authTypesFor(schema: ModelSchema,
operations: [ModelOperation]) -> AWSAuthorizationTypeIterator {
return AWSAuthorizationTypeIterator(withValues: [])
}
}

// MARK: - AWSMultiAuthModeStrategy
Expand Down Expand Up @@ -188,19 +195,35 @@ public class AWSMultiAuthModeStrategy: AuthModeStrategy {
/// - Returns: an iterator for the applicable auth rules
public func authTypesFor(schema: ModelSchema,
operation: ModelOperation) async -> AWSAuthorizationTypeIterator {
var applicableAuthRules = schema.authRules
.filter(modelOperation: operation)
return await authTypesFor(schema: schema, operations: [operation])
}

/// Returns the union of authorization types for the provided schema for the given list of operations
/// - Parameters:
/// - schema: model schema
/// - operations: model operations
/// - Returns: an iterator for the applicable auth rules
public func authTypesFor(schema: ModelSchema,
operations: [ModelOperation]) async -> AWSAuthorizationTypeIterator {
var sortedRules = operations
.flatMap { schema.authRules.filter(modelOperation: $0) }
.reduce(into: [AuthRule](), { array, rule in
if !array.contains(rule) {
array.append(rule)
}
})
.sorted(by: AWSMultiAuthModeStrategy.comparator)

// if there isn't a user signed in, returns only public or custom rules
if let authDelegate = authDelegate, await !authDelegate.isUserLoggedIn() {
applicableAuthRules = applicableAuthRules.filter { rule in
sortedRules = sortedRules.filter { rule in
return rule.allow == .public || rule.allow == .custom
}
}
let applicableAuthTypes = applicableAuthRules.map {
let applicableAuthTypes = sortedRules.map {
AWSMultiAuthModeStrategy.authTypeFor(authRule: $0)
}
return AWSAuthorizationTypeIterator(withValues: applicableAuthTypes)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ class AuthModeStrategyTests: XCTestCase {
let authMode = AWSMultiAuthModeStrategy()
let delegate = UnauthenticatedUserDelegate()
authMode.authDelegate = delegate

var authTypesIterator = await authMode.authTypesFor(schema: ModelWithOwnerAndPublicAuth.schema,
operation: .create)
operation: .create)
XCTAssertEqual(authTypesIterator.count, 1)
XCTAssertEqual(authTypesIterator.next(), .apiKey)
}
Expand All @@ -101,7 +101,7 @@ class AuthModeStrategyTests: XCTestCase {
func testMultiAuthPriorityWithCustomStrategy() async {
let authMode = AWSMultiAuthModeStrategy()
var authTypesIterator = await authMode.authTypesFor(schema: ModelWithCustomStrategy.schema,
operation: .create)
operation: .create)
XCTAssertEqual(authTypesIterator.count, 3)
XCTAssertEqual(authTypesIterator.next(), .function)
XCTAssertEqual(authTypesIterator.next(), .amazonCognitoUserPools)
Expand All @@ -117,12 +117,35 @@ class AuthModeStrategyTests: XCTestCase {
authMode.authDelegate = delegate

var authTypesIterator = await authMode.authTypesFor(schema: ModelWithCustomStrategy.schema,
operation: .create)
operation: .create)
XCTAssertEqual(authTypesIterator.count, 2)
XCTAssertEqual(authTypesIterator.next(), .function)
XCTAssertEqual(authTypesIterator.next(), .awsIAM)
}

// Given: multi-auth strategy and a model schema without auth provider
// When: auth types are requested with multiple operation
// Then: default values based on the auth strategy should be returned
func testMultiAuthShouldReturnDefaultAuthTypesForMultipleOperation() async {
let authMode = AWSMultiAuthModeStrategy()
var authTypesIterator = await authMode.authTypesFor(schema: ModelNoProvider.schema, operations: [.read, .create])
XCTAssertEqual(authTypesIterator.count, 2)
XCTAssertEqual(authTypesIterator.next(), .amazonCognitoUserPools)
XCTAssertEqual(authTypesIterator.next(), .apiKey)
}

// Given: multi-auth strategy and a model schema with auth provider
// When: auth types are requested with multiple operation
// Then: auth rule for public access should be returned
func testMultiAuthReturnDefaultAuthTypesForMultipleOperationWithProvider() async {
let authMode = AWSMultiAuthModeStrategy()
let delegate = UnauthenticatedUserDelegate()
authMode.authDelegate = delegate
var authTypesIterator = await authMode.authTypesFor(schema: ModelNoProvider.schema, operations: [.read, .create])
XCTAssertEqual(authTypesIterator.count, 1)
XCTAssertEqual(authTypesIterator.next(), .apiKey)
}

}

// MARK: - Test models
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ final class IncomingAsyncSubscriptionEventPublisher: AmplifyCancellable {
// onCreate operation
let onCreateValueListener = onCreateValueListenerHandler(event:)
let onCreateAuthTypeProvider = await authModeStrategy.authTypesFor(schema: modelSchema,
operation: .create)
operations: [.create, .read])
self.onCreateValueListener = onCreateValueListener
self.onCreateOperation = RetryableGraphQLSubscriptionOperation(
requestFactory: IncomingAsyncSubscriptionEventPublisher.apiRequestFactoryFor(
Expand All @@ -94,7 +94,7 @@ final class IncomingAsyncSubscriptionEventPublisher: AmplifyCancellable {
// onUpdate operation
let onUpdateValueListener = onUpdateValueListenerHandler(event:)
let onUpdateAuthTypeProvider = await authModeStrategy.authTypesFor(schema: modelSchema,
operation: .update)
operations: [.update, .read])
self.onUpdateValueListener = onUpdateValueListener
self.onUpdateOperation = RetryableGraphQLSubscriptionOperation(
requestFactory: IncomingAsyncSubscriptionEventPublisher.apiRequestFactoryFor(
Expand All @@ -115,7 +115,7 @@ final class IncomingAsyncSubscriptionEventPublisher: AmplifyCancellable {
// onDelete operation
let onDeleteValueListener = onDeleteValueListenerHandler(event:)
let onDeleteAuthTypeProvider = await authModeStrategy.authTypesFor(schema: modelSchema,
operation: .delete)
operations: [.delete, .read])
self.onDeleteValueListener = onDeleteValueListener
self.onDeleteOperation = RetryableGraphQLSubscriptionOperation(
requestFactory: IncomingAsyncSubscriptionEventPublisher.apiRequestFactoryFor(
Expand Down

0 comments on commit d4b957d

Please sign in to comment.