Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(auth): fix credential decoding #3938

Merged
merged 6 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ extension AWSCognitoAuthCredentialStore: AmplifyAuthCredentialStoreBehavior {
func retrieveCredential() throws -> AmplifyCredentials {
let authCredentialStoreKey = generateSessionKey(for: authConfiguration)
let authCredentialData = try keychain._getData(authCredentialStoreKey)
let awsCredential: AmplifyCredentials = try decode(data: authCredentialData)
return awsCredential
let amplifyCredential: AmplifyCredentials = try decode(data: authCredentialData)
return amplifyCredential
}

func deleteCredential() throws {
Expand Down Expand Up @@ -191,15 +191,15 @@ private extension AWSCognitoAuthCredentialStore {
do {
return try JSONEncoder().encode(object)
} catch {
throw KeychainStoreError.codingError("Error occurred while encoding AWSCredentials", error)
throw KeychainStoreError.codingError("Error occurred while encoding credentials", error)
}
}

func decode<T: Decodable>(data: Data) throws -> T {
do {
return try JSONDecoder().decode(T.self, from: data)
} catch {
throw KeychainStoreError.codingError("Error occurred while decoding AWSCredentials", error)
throw KeychainStoreError.codingError("Error occurred while decoding credentials", error)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ public enum AuthFlowType {

internal init?(rawValue: String) {
switch rawValue {
case "CUSTOM_AUTH":
case "CUSTOM_AUTH", "CUSTOM_AUTH_WITH_SRP":
self = .customWithSRP
case "CUSTOM_AUTH_WITHOUT_SRP":
self = .customWithoutSRP
case "USER_SRP_AUTH":
self = .userSRP
case "USER_PASSWORD_AUTH":
Expand All @@ -51,8 +53,10 @@ public enum AuthFlowType {

var rawValue: String {
switch self {
case .custom, .customWithSRP, .customWithoutSRP:
return "CUSTOM_AUTH"
case .custom, .customWithSRP:
return "CUSTOM_AUTH_WITH_SRP"
case .customWithoutSRP:
return "CUSTOM_AUTH_WITHOUT_SRP"
case .userSRP:
return "USER_SRP_AUTH"
case .userPassword:
Expand All @@ -62,6 +66,24 @@ public enum AuthFlowType {
}
}

// This initializer has been added to migrate credentials that were created in the pre-passwordless era
internal static func legacyInit(rawValue: String) -> Self? {
switch rawValue {
case "userSRP":
return .userSRP
case "userPassword":
return .userPassword
case "custom":
return .custom
case "customWithSRP":
return .customWithSRP
case "customWithoutSRP":
return .customWithoutSRP
default:
return nil
}
}

public static var userAuth: AuthFlowType {
return .userAuth(preferredFirstFactor: nil)
}
Expand Down Expand Up @@ -110,27 +132,49 @@ extension AuthFlowType: Codable {

// Decoding the enum
public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
let container: KeyedDecodingContainer<CodingKeys>
do {
container = try decoder.container(keyedBy: CodingKeys.self)
} catch DecodingError.typeMismatch {
// The type mismatch has been added to handle a scenario where the user is migrating passwordless flows.
// Passwordless flow added a new enum case with a associated type.
// The association resulted in encoding structure changes that is different from the non-passwordless flows.
// The structure change causes the type mismatch exception and this code block tries to retrieve the legacy structure and decode it.
let legacyContainer = try decoder.singleValueContainer()
let type = try legacyContainer.decode(String.self)
guard let authFlowType = AuthFlowType.legacyInit(rawValue: type) else {
throw DecodingError.dataCorruptedError(in: legacyContainer, debugDescription: "Invalid AuthFlowType value")
}
self = authFlowType
return
} catch {
throw error
}

// Decode the type (raw value)
let type = try container.decode(String.self, forKey: .type)

// Initialize based on the type
switch type {
case "USER_SRP_AUTH":
self = .userSRP
case "CUSTOM_AUTH":
// Depending on your needs, choose either `.custom`, `.customWithSRP`, or `.customWithoutSRP`
// In this case, we'll default to `.custom`
self = .custom
case "CUSTOM_AUTH", "CUSTOM_AUTH_WITH_SRP":
self = .customWithSRP
case "CUSTOM_AUTH_WITHOUT_SRP":
self = .customWithoutSRP
case "USER_PASSWORD_AUTH":
self = .userPassword
case "USER_AUTH":
let preferredFirstFactorString = try container.decode(String.self, forKey: .preferredFirstFactor)
if let preferredFirstFactor = AuthFactorType(rawValue: preferredFirstFactorString) {
self = .userAuth(preferredFirstFactor: preferredFirstFactor)
if let preferredFirstFactorString = try container.decodeIfPresent(String.self, forKey: .preferredFirstFactor) {
if let preferredFirstFactor = AuthFactorType(rawValue: preferredFirstFactorString) {
self = .userAuth(preferredFirstFactor: preferredFirstFactor)
} else {
throw DecodingError.dataCorruptedError(
forKey: .preferredFirstFactor,
in: container,
debugDescription: "Unable to decode preferredFirstFactor value")
}
} else {
throw DecodingError.dataCorruptedError(forKey: .type, in: container, debugDescription: "Unable to decode preferredFirstFactor value")
self = .userAuth(preferredFirstFactor: nil)
}
default:
throw DecodingError.dataCorruptedError(forKey: .type, in: container, debugDescription: "Invalid AuthFlowType value")
Expand All @@ -152,5 +196,4 @@ extension AuthFlowType {
return .userAuth
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
//
// Copyright Amazon.com Inc. or its affiliates.
// All Rights Reserved.
//
// SPDX-License-Identifier: Apache-2.0
//


import XCTest
@testable import AWSCognitoAuthPlugin

class AuthFlowTypeTests: XCTestCase {

func testRawValue() {
XCTAssertEqual(AuthFlowType.userSRP.rawValue, "USER_SRP_AUTH")
XCTAssertEqual(AuthFlowType.customWithSRP.rawValue, "CUSTOM_AUTH_WITH_SRP")
XCTAssertEqual(AuthFlowType.customWithoutSRP.rawValue, "CUSTOM_AUTH_WITHOUT_SRP")
XCTAssertEqual(AuthFlowType.userPassword.rawValue, "USER_PASSWORD_AUTH")
XCTAssertEqual(AuthFlowType.userAuth(preferredFirstFactor: nil).rawValue, "USER_AUTH")
}

func testInitWithRawValue() {
XCTAssertEqual(AuthFlowType(rawValue: "USER_SRP_AUTH"), .userSRP)
XCTAssertEqual(AuthFlowType(rawValue: "CUSTOM_AUTH"), .customWithSRP)
XCTAssertEqual(AuthFlowType(rawValue: "CUSTOM_AUTH_WITH_SRP"), .customWithSRP)
XCTAssertEqual(AuthFlowType(rawValue: "CUSTOM_AUTH_WITHOUT_SRP"), .customWithoutSRP)
XCTAssertEqual(AuthFlowType(rawValue: "USER_PASSWORD_AUTH"), .userPassword)
XCTAssertEqual(AuthFlowType(rawValue: "USER_AUTH"), .userAuth(preferredFirstFactor: nil))
XCTAssertNil(AuthFlowType(rawValue: "INVALID_AUTH"))
}

func testDeprecatedCustom() {
// This test is to ensure the deprecated case is still functional
XCTAssertEqual(AuthFlowType.custom.rawValue, "CUSTOM_AUTH_WITH_SRP")
}

func testEncoding() throws {
let encoder = JSONEncoder()
let userSRP = try encoder.encode(AuthFlowType.userSRP)
XCTAssertEqual(String(data: userSRP, encoding: .utf8), "{\"type\":\"USER_SRP_AUTH\"}")

let customWithSRP = try encoder.encode(AuthFlowType.customWithSRP)
XCTAssertEqual(String(data: customWithSRP, encoding: .utf8), "{\"type\":\"CUSTOM_AUTH_WITH_SRP\"}")

let customWithoutSRP = try encoder.encode(AuthFlowType.customWithoutSRP)
XCTAssertEqual(String(data: customWithoutSRP, encoding: .utf8), "{\"type\":\"CUSTOM_AUTH_WITHOUT_SRP\"}")

let userPassword = try encoder.encode(AuthFlowType.userPassword)
XCTAssertEqual(String(data: userPassword, encoding: .utf8), "{\"type\":\"USER_PASSWORD_AUTH\"}")

let userAuth = try encoder.encode(AuthFlowType.userAuth(preferredFirstFactor: nil))
XCTAssertTrue(String(data: userAuth, encoding: .utf8)?.contains("\"preferredFirstFactor\":null") == true)
XCTAssertTrue(String(data: userAuth, encoding: .utf8)?.contains("\"type\":\"USER_AUTH\"") == true)
}

func testDecoding() throws {
let decoder = JSONDecoder()
let userSRP = try decoder.decode(AuthFlowType.self, from: "{\"type\":\"USER_SRP_AUTH\"}".data(using: .utf8)!)
XCTAssertEqual(userSRP, .userSRP)

let customWithSRP = try decoder.decode(AuthFlowType.self, from: "{\"type\":\"CUSTOM_AUTH_WITH_SRP\"}".data(using: .utf8)!)
XCTAssertEqual(customWithSRP, .customWithSRP)

let customWithoutSRP = try decoder.decode(AuthFlowType.self, from: "{\"type\":\"CUSTOM_AUTH_WITHOUT_SRP\"}".data(using: .utf8)!)
XCTAssertEqual(customWithoutSRP, .customWithoutSRP)

let userPassword = try decoder.decode(AuthFlowType.self, from: "{\"type\":\"USER_PASSWORD_AUTH\"}".data(using: .utf8)!)
XCTAssertEqual(userPassword, .userPassword)

let userAuth = try decoder.decode(AuthFlowType.self, from: "{\"type\":\"USER_AUTH\"}".data(using: .utf8)!)
XCTAssertEqual(userAuth, .userAuth(preferredFirstFactor: nil))
}

func testDecodingWithPreferredFirstFactor() throws {
let decoder = JSONDecoder()
let json = """
{
"type": "USER_AUTH",
"preferredFirstFactor": "SMS_OTP"
}
""".data(using: .utf8)!
let authFlowType = try decoder.decode(AuthFlowType.self, from: json)
XCTAssertEqual(authFlowType, .userAuth(preferredFirstFactor: .smsOTP))
}

func testDecodingLegacyStructure() throws {
let decoder = JSONDecoder()
var legacyJson = "\"userSRP\"".data(using: .utf8)!
var authFlowType = try decoder.decode(AuthFlowType.self, from: legacyJson)
XCTAssertEqual(authFlowType, .userSRP)

legacyJson = "\"userPassword\"".data(using: .utf8)!
authFlowType = try decoder.decode(AuthFlowType.self, from: legacyJson)
XCTAssertEqual(authFlowType, .userPassword)

legacyJson = "\"customWithSRP\"".data(using: .utf8)!
authFlowType = try decoder.decode(AuthFlowType.self, from: legacyJson)
XCTAssertEqual(authFlowType, .customWithSRP)

legacyJson = "\"customWithoutSRP\"".data(using: .utf8)!
authFlowType = try decoder.decode(AuthFlowType.self, from: legacyJson)
XCTAssertEqual(authFlowType, .customWithoutSRP)

legacyJson = "\"custom\"".data(using: .utf8)!
authFlowType = try decoder.decode(AuthFlowType.self, from: legacyJson)
XCTAssertEqual(authFlowType, .custom)
}

func testDecodingInvalidType() {
let decoder = JSONDecoder()
let invalidJson = "{\"type\":\"INVALID_AUTH\"}".data(using: .utf8)!
XCTAssertThrowsError(try decoder.decode(AuthFlowType.self, from: invalidJson)) { error in
guard case DecodingError.dataCorrupted(let context) = error else {
return XCTFail("Expected dataCorrupted error")
}
XCTAssertEqual(context.debugDescription, "Invalid AuthFlowType value")
}
}

func testDecodingInvalidPreferredFirstFactor() {
let decoder = JSONDecoder()
let invalidJson = """
{
"type": "USER_AUTH",
"preferredFirstFactor": "INVALID_FACTOR"
}
""".data(using: .utf8)!
XCTAssertThrowsError(try decoder.decode(AuthFlowType.self, from: invalidJson)) { error in
guard case DecodingError.dataCorrupted(let context) = error else {
return XCTFail("Expected dataCorrupted error")
}
XCTAssertEqual(context.debugDescription, "Unable to decode preferredFirstFactor value")
}
}

func testGetClientFlowType() {
XCTAssertEqual(AuthFlowType.custom.getClientFlowType(), .customAuth)
XCTAssertEqual(AuthFlowType.customWithSRP.getClientFlowType(), .customAuth)
XCTAssertEqual(AuthFlowType.customWithoutSRP.getClientFlowType(), .customAuth)
XCTAssertEqual(AuthFlowType.userSRP.getClientFlowType(), .userSrpAuth)
XCTAssertEqual(AuthFlowType.userPassword.getClientFlowType(), .userPasswordAuth)
XCTAssertEqual(AuthFlowType.userAuth(preferredFirstFactor: nil).getClientFlowType(), .userAuth)
}
}
Loading