Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(auth): Add userAttributes to confirmSignIn call #2640

Merged
merged 3 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -785,12 +785,15 @@ internal class RealAWSCognitoAuthPlugin(
},
{
val awsCognitoConfirmSignInOptions = options as? AWSCognitoAuthConfirmSignInOptions
val metadata = awsCognitoConfirmSignInOptions?.metadata ?: emptyMap()
val userAttributes = awsCognitoConfirmSignInOptions?.userAttributes ?: emptyList()
when (signInState) {
is SignInState.ResolvingChallenge -> {
val event = SignInChallengeEvent(
SignInChallengeEvent.EventType.VerifyChallengeAnswer(
challengeResponse,
awsCognitoConfirmSignInOptions?.metadata ?: mapOf()
metadata,
userAttributes
)
)
authStateMachine.send(event)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package com.amplifyframework.auth.cognito.actions
import aws.sdk.kotlin.services.cognitoidentityprovider.model.ChallengeNameType
import aws.sdk.kotlin.services.cognitoidentityprovider.model.ResourceNotFoundException
import aws.sdk.kotlin.services.cognitoidentityprovider.respondToAuthChallenge
import com.amplifyframework.auth.AuthUserAttribute
import com.amplifyframework.auth.cognito.AuthEnvironment
import com.amplifyframework.auth.cognito.helpers.AuthHelper
import com.amplifyframework.auth.cognito.helpers.SignInChallengeHelper
Expand All @@ -32,9 +33,11 @@ import com.amplifyframework.statemachine.codegen.events.SignInChallengeEvent
internal object SignInChallengeCognitoActions : SignInChallengeActions {
private const val KEY_SECRET_HASH = "SECRET_HASH"
private const val KEY_USERNAME = "USERNAME"
private const val KEY_PREFIX_USER_ATTRIBUTE = "userAttributes."
override fun verifyChallengeAuthAction(
answer: String,
metadata: Map<String, String>,
attributes: List<AuthUserAttribute>,
challenge: AuthChallenge
): Action = Action<AuthEnvironment>("VerifySignInChallenge") { id, dispatcher ->
logger.verbose("$id Starting execution")
Expand All @@ -50,6 +53,12 @@ internal object SignInChallengeCognitoActions : SignInChallengeActions {
challengeResponses[responseKey] = answer
}

challengeResponses.putAll(
attributes.map {
Pair("${KEY_PREFIX_USER_ATTRIBUTE}${it.key.keyString}", it.value)
}
)

val secretHash = AuthHelper.getSecretHash(
username,
configuration.userPool?.appClient,
Expand Down Expand Up @@ -90,6 +99,7 @@ internal object SignInChallengeCognitoActions : SignInChallengeActions {
SignInChallengeEvent.EventType.RetryVerifyChallengeAnswer(
answer,
metadata,
attributes,
challenge
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@

package com.amplifyframework.statemachine.codegen.actions

import com.amplifyframework.auth.AuthUserAttribute
import com.amplifyframework.statemachine.Action
import com.amplifyframework.statemachine.codegen.data.AuthChallenge

internal interface SignInChallengeActions {
fun verifyChallengeAuthAction(
answer: String,
metadata: Map<String, String>,
userAttributes: List<AuthUserAttribute>,
challenge: AuthChallenge
): Action
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,24 @@

package com.amplifyframework.statemachine.codegen.events

import com.amplifyframework.auth.AuthUserAttribute
import com.amplifyframework.statemachine.StateMachineEvent
import com.amplifyframework.statemachine.codegen.data.AuthChallenge
import java.util.Date

internal class SignInChallengeEvent(val eventType: EventType, override val time: Date? = null) : StateMachineEvent {
sealed class EventType {
data class WaitForAnswer(val challenge: AuthChallenge, val hasNewResponse: Boolean = false) : EventType()
data class VerifyChallengeAnswer(val answer: String, val metadata: Map<String, String>) : EventType()
data class VerifyChallengeAnswer(
val answer: String,
val metadata: Map<String, String>,
val userAttributes: List<AuthUserAttribute>
) : EventType()

data class RetryVerifyChallengeAnswer(
val answer: String,
val metadata: Map<String, String>,
val userAttributes: List<AuthUserAttribute>,
val authChallenge: AuthChallenge
) : EventType()
data class FinalizeSignIn(val accessToken: String) : EventType()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ internal sealed class SignInChallengeState : State {
is WaitingForAnswer -> when (challengeEvent) {
is SignInChallengeEvent.EventType.VerifyChallengeAnswer -> {
val action = challengeActions.verifyChallengeAuthAction(
challengeEvent.answer, challengeEvent.metadata, oldState.challenge
challengeEvent.answer,
challengeEvent.metadata,
challengeEvent.userAttributes,
oldState.challenge
)
StateResolution(Verifying(oldState.challenge.challengeName), listOf(action))
}
Expand All @@ -78,7 +81,10 @@ internal sealed class SignInChallengeState : State {
}
is SignInChallengeEvent.EventType.RetryVerifyChallengeAnswer -> {
val action = challengeActions.verifyChallengeAuthAction(
challengeEvent.answer, challengeEvent.metadata, challengeEvent.authChallenge
challengeEvent.answer,
challengeEvent.metadata,
challengeEvent.userAttributes,
challengeEvent.authChallenge,
)
StateResolution(Verifying(challengeEvent.authChallenge.challengeName), listOf(action))
}
Expand All @@ -92,7 +98,10 @@ internal sealed class SignInChallengeState : State {
when (challengeEvent) {
is SignInChallengeEvent.EventType.VerifyChallengeAnswer -> {
val action = challengeActions.verifyChallengeAuthAction(
challengeEvent.answer, challengeEvent.metadata, oldState.challenge
challengeEvent.answer,
challengeEvent.metadata,
challengeEvent.userAttributes,
oldState.challenge,
)
StateResolution(Verifying(oldState.challenge.challengeName), listOf(action))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ open class StateTransitionTestBase {

Mockito.`when`(
mockSignInChallengeActions.verifyChallengeAuthAction(
MockitoHelper.anyObject(),
MockitoHelper.anyObject(),
MockitoHelper.anyObject(),
MockitoHelper.anyObject()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,12 +387,13 @@ class StateTransitionTests : StateTransitionTestBase() {
SignInChallengeEvent(
SignInChallengeEvent.EventType.RetryVerifyChallengeAnswer(
"test",
mapOf(),
emptyMap(),
emptyList(),
AuthChallenge(
ChallengeNameType.CustomChallenge.toString(),
"Test",
"session_mock_value",
mapOf()
emptyMap(),
)
)
)
Expand All @@ -401,7 +402,8 @@ class StateTransitionTests : StateTransitionTestBase() {
SignInChallengeEvent(
SignInChallengeEvent.EventType.VerifyChallengeAnswer(
"test",
mapOf()
emptyMap(),
emptyList()
)
)
)
Expand Down Expand Up @@ -481,7 +483,7 @@ class StateTransitionTests : StateTransitionTestBase() {
challengeState?.apply {
stateMachine.send(
SignInChallengeEvent(
SignInChallengeEvent.EventType.VerifyChallengeAnswer("test", mapOf())
SignInChallengeEvent.EventType.VerifyChallengeAnswer("test", emptyMap(), emptyList())
)
)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package com.amplifyframework.auth.cognito.actions

import androidx.test.core.app.ApplicationProvider
import aws.sdk.kotlin.services.cognitoidentityprovider.CognitoIdentityProviderClient
import aws.sdk.kotlin.services.cognitoidentityprovider.model.RespondToAuthChallengeRequest
import com.amplifyframework.auth.AuthUserAttribute
import com.amplifyframework.auth.AuthUserAttributeKey
import com.amplifyframework.auth.cognito.AWSCognitoAuthService
import com.amplifyframework.auth.cognito.AuthEnvironment
import com.amplifyframework.auth.cognito.StoreClientBehavior
import com.amplifyframework.logging.Logger
import com.amplifyframework.statemachine.EventDispatcher
import com.amplifyframework.statemachine.StateMachineEvent
import com.amplifyframework.statemachine.codegen.data.AmplifyCredential
import com.amplifyframework.statemachine.codegen.data.AuthChallenge
import com.amplifyframework.statemachine.codegen.data.AuthConfiguration
import com.amplifyframework.statemachine.codegen.data.CredentialType
import com.amplifyframework.statemachine.codegen.data.UserPoolConfiguration
import io.mockk.coEvery
import io.mockk.every
import io.mockk.mockk
import io.mockk.slot
import junit.framework.TestCase.assertTrue
import kotlin.test.assertEquals
import kotlinx.coroutines.test.runTest
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner

@RunWith(RobolectricTestRunner::class)
class SignInChallengeCognitoActionsTest {

private val pool = mockk<UserPoolConfiguration> {
every { appClient } returns "client"
every { appClientSecret } returns null
every { pinpointAppId } returns null
}
private val configuration = mockk<AuthConfiguration> {
every { userPool } returns pool
}
private val cognitoAuthService = mockk<AWSCognitoAuthService>()
private val credentialStoreClient = mockk<StoreClientBehavior> {
coEvery { loadCredentials(CredentialType.ASF) } returns AmplifyCredential.ASFDevice("asf_id")
}
private val logger = mockk<Logger>(relaxed = true)
private val cognitoIdentityProviderClientMock = mockk<CognitoIdentityProviderClient>()

private val capturedEvent = slot<StateMachineEvent>()
private val dispatcher = mockk<EventDispatcher> {
every { send(capture(capturedEvent)) }.answers { }
}

private lateinit var authEnvironment: AuthEnvironment

@Before
fun setup() {
every { cognitoAuthService.cognitoIdentityProviderClient }.answers { cognitoIdentityProviderClientMock }
authEnvironment = AuthEnvironment(
ApplicationProvider.getApplicationContext(),
configuration,
cognitoAuthService,
credentialStoreClient,
null,
null,
logger
)
}

@Test
fun `very auth challenge without user attributes`() = runTest {
val expectedChallengeResponses = mapOf(
"USERNAME" to "testUser"
)
val capturedRequest = slot<RespondToAuthChallengeRequest>()
coEvery {
cognitoIdentityProviderClientMock.respondToAuthChallenge(capture(capturedRequest))
}.answers {
mockk()
}

SignInChallengeCognitoActions.verifyChallengeAuthAction(
"myAnswer",
emptyMap(),
emptyList(),
AuthChallenge(
"CONFIRM_SIGN_IN_WITH_NEW_PASSWORD",
username = "testUser",
session = null,
parameters = null
)
).execute(dispatcher, authEnvironment)

assertTrue(capturedRequest.isCaptured)
assertEquals(expectedChallengeResponses, capturedRequest.captured.challengeResponses)
}

@Test
fun `user attributes are added to auth challenge`() = runTest {
val providedUserAttributes = listOf(AuthUserAttribute(AuthUserAttributeKey.phoneNumber(), "+15555555555"))
val expectedChallengeResponses = mapOf(
"USERNAME" to "testUser",
"userAttributes.phone_number" to "+15555555555"
)
val capturedRequest = slot<RespondToAuthChallengeRequest>()
coEvery {
cognitoIdentityProviderClientMock.respondToAuthChallenge(capture(capturedRequest))
}.answers {
mockk()
}

SignInChallengeCognitoActions.verifyChallengeAuthAction(
"myAnswer",
emptyMap(),
providedUserAttributes,
AuthChallenge(
"CONFIRM_SIGN_IN_WITH_NEW_PASSWORD",
username = "testUser",
session = null,
parameters = null
)
).execute(dispatcher, authEnvironment)

assertTrue(capturedRequest.isCaptured)
assertEquals(expectedChallengeResponses, capturedRequest.captured.challengeResponses)
}
}
Loading