diff --git a/aws-predictions/src/main/java/com/amplifyframework/predictions/aws/exceptions/FaceLivenessUnsupportedChallengeTypeException.kt b/aws-predictions/src/main/java/com/amplifyframework/predictions/aws/exceptions/FaceLivenessUnsupportedChallengeTypeException.kt new file mode 100644 index 0000000000..0ecd19a1fe --- /dev/null +++ b/aws-predictions/src/main/java/com/amplifyframework/predictions/aws/exceptions/FaceLivenessUnsupportedChallengeTypeException.kt @@ -0,0 +1,26 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +package com.amplifyframework.predictions.aws.exceptions + +import com.amplifyframework.annotations.InternalAmplifyApi +import com.amplifyframework.predictions.PredictionsException + +@InternalAmplifyApi +class FaceLivenessUnsupportedChallengeTypeException internal constructor( + message: String = "Received an unsupported ChallengeType from the backend.", + cause: Throwable? = null, + recoverySuggestion: String = "Verify that the Challenges configured in your backend are supported by the " + + "frontend code (e.g. Amplify UI)" +) : PredictionsException(message, cause, recoverySuggestion) diff --git a/aws-predictions/src/main/java/com/amplifyframework/predictions/aws/http/LivenessWebSocket.kt b/aws-predictions/src/main/java/com/amplifyframework/predictions/aws/http/LivenessWebSocket.kt index 4eedf4ebf7..fd93e89f0e 100644 --- a/aws-predictions/src/main/java/com/amplifyframework/predictions/aws/http/LivenessWebSocket.kt +++ b/aws-predictions/src/main/java/com/amplifyframework/predictions/aws/http/LivenessWebSocket.kt @@ -30,6 +30,7 @@ import com.amplifyframework.predictions.PredictionsException import com.amplifyframework.predictions.aws.BuildConfig import com.amplifyframework.predictions.aws.exceptions.AccessDeniedException import com.amplifyframework.predictions.aws.exceptions.FaceLivenessSessionNotFoundException +import com.amplifyframework.predictions.aws.exceptions.FaceLivenessUnsupportedChallengeTypeException import com.amplifyframework.predictions.aws.models.liveness.BoundingBox import com.amplifyframework.predictions.aws.models.liveness.ClientChallenge import com.amplifyframework.predictions.aws.models.liveness.ClientSessionInformationEvent @@ -78,10 +79,14 @@ internal class LivenessWebSocket( val region: String, val clientSessionInformation: FaceLivenessSessionInformation, val livenessVersion: String?, - val onSessionInformationReceived: Consumer, + val onSessionResponseReceived: Consumer, val onErrorReceived: Consumer, val onComplete: Action ) { + internal data class SessionResponse( + val faceLivenessSession: SessionInformation, + val livenessChallengeType: FaceLivenessChallengeType + ) private val signer = AWSV4Signer() private var credentials: Credentials? = null @@ -165,17 +170,17 @@ internal class LivenessWebSocket( // If challengeType hasn't been initialized by this point it's because server sent an // unsupported challenge type so return an error to the client. - if (challengeType == null) { - onErrorReceived.accept( - PredictionsException( - "Received an unsupported ChallengeType from the backend", - "Verify that the Challenges configured in your backend are supported by the " + - "frontend code (e.g. Amplify UI)" - ) - ) + val resolvedChallengeType = challengeType + if (resolvedChallengeType == null) { + webSocketError = FaceLivenessUnsupportedChallengeTypeException() + destroy(UNSUPPORTED_CHALLENGE_CLOSURE_STATUS_CODE) } else { - onSessionInformationReceived.accept( - response.serverSessionInformationEvent.sessionInformation + // send non-null challenge Type + onSessionResponseReceived.accept( + SessionResponse( + response.serverSessionInformationEvent.sessionInformation, + resolvedChallengeType + ) ) } } else if (response.disconnectionEvent != null) { @@ -394,7 +399,8 @@ internal class LivenessWebSocket( ) faceDetectedStart = adjustedDate(videoStartTime) - if (challengeType == null) { + val resolvedChallengeType = challengeType + if (resolvedChallengeType == null) { onErrorReceived.accept( PredictionsException( "Failed to send an initial face detected event", @@ -405,7 +411,7 @@ internal class LivenessWebSocket( val clientInfoEvent = ClientSessionInformationEvent( challenge = buildClientChallenge( - challengeType = challengeType!!, + challengeType = resolvedChallengeType, challengeId = challengeId, initialFace = InitialFace( boundingBox = initialDetectedFace!!, @@ -419,7 +425,8 @@ internal class LivenessWebSocket( } fun sendFinalEvent(targetFaceRect: RectF, faceMatchedStart: Long, faceMatchedEnd: Long) { - if (challengeType == null) { + val resolvedChallengeType = challengeType + if (resolvedChallengeType == null) { onErrorReceived.accept( PredictionsException( "Failed to send an initial face detected event", @@ -429,7 +436,7 @@ internal class LivenessWebSocket( } else { val finalClientInfoEvent = ClientSessionInformationEvent( challenge = buildClientChallenge( - challengeType = challengeType!!, + challengeType = resolvedChallengeType, challengeId = challengeId, videoEndTimestamp = videoEndTimestamp, initialFace = InitialFace( @@ -607,6 +614,8 @@ internal class LivenessWebSocket( companion object { private const val NORMAL_SOCKET_CLOSURE_STATUS_CODE = 1000 + // This is the same as the client-provided 'runtime error' status code + private const val UNSUPPORTED_CHALLENGE_CLOSURE_STATUS_CODE = 4005 private const val FOUR_MINUTES = 1000 * 60 * 4 @VisibleForTesting val datePattern = "EEE, d MMM yyyy HH:mm:ss z" private val LOG = Amplify.Logging.logger(CategoryType.PREDICTIONS, "amplify:aws-predictions") diff --git a/aws-predictions/src/main/java/com/amplifyframework/predictions/aws/service/RunFaceLivenessSession.kt b/aws-predictions/src/main/java/com/amplifyframework/predictions/aws/service/RunFaceLivenessSession.kt index a5f3df2652..f515248846 100644 --- a/aws-predictions/src/main/java/com/amplifyframework/predictions/aws/service/RunFaceLivenessSession.kt +++ b/aws-predictions/src/main/java/com/amplifyframework/predictions/aws/service/RunFaceLivenessSession.kt @@ -36,7 +36,6 @@ import com.amplifyframework.predictions.aws.models.liveness.FreshnessColor import com.amplifyframework.predictions.aws.models.liveness.OvalParameters import com.amplifyframework.predictions.aws.models.liveness.SessionInformation import com.amplifyframework.predictions.models.ChallengeResponseEvent -import com.amplifyframework.predictions.models.FaceLivenessChallengeType import com.amplifyframework.predictions.models.FaceLivenessSession import com.amplifyframework.predictions.models.FaceLivenessSessionChallenge import com.amplifyframework.predictions.models.FaceLivenessSessionInformation @@ -58,9 +57,9 @@ internal class RunFaceLivenessSession( region = clientSessionInformation.region, clientSessionInformation = clientSessionInformation, livenessVersion = livenessVersion, - onSessionInformationReceived = { serverSessionInformation -> - val challenges = processSessionInformation(serverSessionInformation) - val challengeType = getChallengeType() + onSessionResponseReceived = { serverSessionResponse -> + val challenges = processSessionInformation(serverSessionResponse.faceLivenessSession) + val challengeType = serverSessionResponse.livenessChallengeType val faceLivenessSession = FaceLivenessSession( challengeId = getChallengeId(), challengeType = challengeType, @@ -128,8 +127,6 @@ internal class RunFaceLivenessSession( private fun getChallengeId(): String = livenessWebSocket.challengeId - private fun getChallengeType(): FaceLivenessChallengeType = livenessWebSocket.challengeType!! - private fun getFaceTargetChallenge( ovalParameters: OvalParameters, challengeConfig: ChallengeConfig diff --git a/aws-predictions/src/test/java/com/amplifyframework/predictions/aws/http/LivenessWebSocketTest.kt b/aws-predictions/src/test/java/com/amplifyframework/predictions/aws/http/LivenessWebSocketTest.kt index 8d9ba7f92c..8dec41dcb7 100644 --- a/aws-predictions/src/test/java/com/amplifyframework/predictions/aws/http/LivenessWebSocketTest.kt +++ b/aws-predictions/src/test/java/com/amplifyframework/predictions/aws/http/LivenessWebSocketTest.kt @@ -23,6 +23,7 @@ import com.amplifyframework.core.Action import com.amplifyframework.core.BuildConfig import com.amplifyframework.core.Consumer import com.amplifyframework.predictions.PredictionsException +import com.amplifyframework.predictions.aws.exceptions.FaceLivenessUnsupportedChallengeTypeException import com.amplifyframework.predictions.aws.models.liveness.ChallengeConfig import com.amplifyframework.predictions.aws.models.liveness.ChallengeEvent import com.amplifyframework.predictions.aws.models.liveness.ColorSequence @@ -86,7 +87,7 @@ internal class LivenessWebSocketTest { private lateinit var server: MockWebServer private val onComplete = mockk(relaxed = true) - private val onSessionInformationReceived = mockk>(relaxed = true) + private val onSessionResponseReceived = mockk>(relaxed = true) private val onErrorReceived = mockk>(relaxed = true) private val credentialsProvider = object : CredentialsProvider { override suspend fun resolve(attributes: Attributes): Credentials { @@ -100,7 +101,7 @@ internal class LivenessWebSocketTest { } } - private val defaultSessionInformation = createSessionInformation( + private val defaultSessionInformation = createClientSessionInformation( listOf(Challenge.FaceMovementChallenge("1.0.0")) ) @@ -213,6 +214,67 @@ internal class LivenessWebSocketTest { assertEquals("AWS4-HMAC-SHA256", originalRequest.url.queryParameter("X-Amz-Algorithm")) } + @Test + fun `test unsupported challengetype`() { + val clientSessionInfo = createClientSessionInformation( + listOf(Challenge.FaceMovementAndLightChallenge("2.0.0")) + ) + val livenessWebSocket = createLivenessWebSocket(clientSessionInformation = clientSessionInfo) + val unknownEvent = "{\"Type\":\"NewChallengeType\",\"Version\":\"1.0.0\"}" + + val challengeEventHeaders = mapOf( + ":event-type" to "ChallengeEvent", + ":content-type" to "application/json", + ":message-type" to "event" + ) + + val encodedChallengeTypeByteString = + LivenessEventStream.encode(unknownEvent.toByteArray(), challengeEventHeaders).array().toByteString() + + livenessWebSocket.webSocketListener.onMessage(mockk(), encodedChallengeTypeByteString) + + assertEquals(null, livenessWebSocket.challengeType) + + val event = ServerSessionInformationEvent( + sessionInformation = SessionInformation( + challenge = ServerChallenge( + faceMovementAndLightChallenge = FaceMovementAndLightServerChallenge( + ovalParameters = OvalParameters(1.0f, 2.0f, .5f, .7f), + lightChallengeType = LightChallengeType.SEQUENTIAL, + challengeConfig = ChallengeConfig( + 1.0f, + 1.1f, + 1.2f, + 1.3f, + 1.4f, + 1.5f, + 1.6f, + 1.7f, + 1.8f, + 1.9f, + 10 + ), + colorSequences = listOf( + ColorSequence(FreshnessColor(listOf(0, 1, 2)), 4.0f, 5.0f) + ) + ) + ) + ) + ) + + val headers = mapOf( + ":event-type" to "ServerSessionInformationEvent", + ":content-type" to "application/json", + ":message-type" to "event" + ) + + val data = json.encodeToString(event) + val encodedByteString = LivenessEventStream.encode(data.toByteArray(), headers).array().toByteString() + + livenessWebSocket.webSocketListener.onMessage(mockk(), encodedByteString) + assertEquals(FaceLivenessUnsupportedChallengeTypeException(), livenessWebSocket.webSocketError) + } + @Test fun `ensure challengetype is properly set when using the deprecated sessioninformation`() { val sessionInfo = FaceLivenessSessionInformation( @@ -221,7 +283,7 @@ internal class LivenessWebSocketTest { "FaceMovementAndLightChallenge_1.0.0", "3" ) - val livenessWebSocket = createLivenessWebSocket(sessionInformation = sessionInfo) + val livenessWebSocket = createLivenessWebSocket(clientSessionInformation = sessionInfo) val event = ChallengeEvent( challengeType = FaceLivenessChallengeType.FaceMovementAndLightChallenge, version = "1.0.0" @@ -243,8 +305,8 @@ internal class LivenessWebSocketTest { @Test fun `server facemovementandlight challenge event tracked`() { - val sessionInfo = createSessionInformation(listOf(Challenge.FaceMovementAndLightChallenge("2.0.0"))) - val livenessWebSocket = createLivenessWebSocket(sessionInformation = sessionInfo) + val sessionInfo = createClientSessionInformation(listOf(Challenge.FaceMovementAndLightChallenge("2.0.0"))) + val livenessWebSocket = createLivenessWebSocket(clientSessionInformation = sessionInfo) val event = ChallengeEvent( challengeType = FaceLivenessChallengeType.FaceMovementAndLightChallenge, version = "2.0.0" @@ -266,8 +328,8 @@ internal class LivenessWebSocketTest { @Test fun `server facemovement challenge event tracked`() { - val sessionInfo = createSessionInformation(listOf(Challenge.FaceMovementChallenge("1.0.0"))) - val livenessWebSocket = createLivenessWebSocket(sessionInformation = sessionInfo) + val sessionInfo = createClientSessionInformation(listOf(Challenge.FaceMovementChallenge("1.0.0"))) + val livenessWebSocket = createLivenessWebSocket(clientSessionInformation = sessionInfo) val event = ChallengeEvent( challengeType = FaceLivenessChallengeType.FaceMovementChallenge, @@ -290,8 +352,8 @@ internal class LivenessWebSocketTest { @Test fun `server facemovementandlight session event tracked`() { - val sessionInfo = createSessionInformation(listOf(Challenge.FaceMovementAndLightChallenge("2.0.0"))) - val livenessWebSocket = createLivenessWebSocket(sessionInformation = sessionInfo) + val sessionInfo = createClientSessionInformation(listOf(Challenge.FaceMovementAndLightChallenge("2.0.0"))) + val livenessWebSocket = createLivenessWebSocket(clientSessionInformation = sessionInfo) val challengeEvent = ChallengeEvent( challengeType = FaceLivenessChallengeType.FaceMovementAndLightChallenge, @@ -335,6 +397,10 @@ internal class LivenessWebSocketTest { ) ) ) + val sessionResponse = LivenessWebSocket.SessionResponse( + event.sessionInformation, + FaceLivenessChallengeType.FaceMovementAndLightChallenge + ) val headers = mapOf( ":event-type" to "ServerSessionInformationEvent", ":content-type" to "application/json", @@ -346,13 +412,13 @@ internal class LivenessWebSocketTest { livenessWebSocket.webSocketListener.onMessage(mockk(), encodedByteString) - verify { onSessionInformationReceived.accept(event.sessionInformation) } + verify { onSessionResponseReceived.accept(sessionResponse) } } @Test fun `server facemovement session event tracked`() { - val sessionInfo = createSessionInformation(listOf(Challenge.FaceMovementChallenge("1.0.0"))) - val livenessWebSocket = createLivenessWebSocket(sessionInformation = sessionInfo) + val sessionInfo = createClientSessionInformation(listOf(Challenge.FaceMovementChallenge("1.0.0"))) + val livenessWebSocket = createLivenessWebSocket(clientSessionInformation = sessionInfo) val challengeEvent = ChallengeEvent( challengeType = FaceLivenessChallengeType.FaceMovementChallenge, @@ -392,6 +458,10 @@ internal class LivenessWebSocketTest { ) ) ) + val sessionResponse = LivenessWebSocket.SessionResponse( + event.sessionInformation, + FaceLivenessChallengeType.FaceMovementChallenge + ) val sessionHeaders = mapOf( ":event-type" to "ServerSessionInformationEvent", ":content-type" to "application/json", @@ -403,7 +473,7 @@ internal class LivenessWebSocketTest { livenessWebSocket.webSocketListener.onMessage(mockk(), encodedByteString) - verify { onSessionInformationReceived.accept(event.sessionInformation) } + verify { onSessionResponseReceived.accept(sessionResponse) } } @Test @@ -423,7 +493,7 @@ internal class LivenessWebSocketTest { livenessWebSocket.webSocketListener.onMessage(mockk(), encodedByteString) - verify(exactly = 0) { onSessionInformationReceived.accept(any()) } + verify(exactly = 0) { onSessionResponseReceived.accept(any()) } verify(exactly = 0) { onErrorReceived.accept(any()) } verify(exactly = 0) { webSocket.close(any(), any()) } } @@ -445,7 +515,7 @@ internal class LivenessWebSocketTest { livenessWebSocket.webSocketListener.onMessage(mockk(), encodedByteString) - verify(exactly = 0) { onSessionInformationReceived.accept(any()) } + verify(exactly = 0) { onSessionResponseReceived.accept(any()) } verify(exactly = 0) { onErrorReceived.accept(any()) } verify(exactly = 1) { webSocket.close(any(), any()) } } @@ -607,7 +677,7 @@ internal class LivenessWebSocketTest { fun `sendVideoEvent test`() { } - private fun createSessionInformation(challengeVersions: List) = FaceLivenessSessionInformation( + private fun createClientSessionInformation(challengeVersions: List) = FaceLivenessSessionInformation( videoWidth = 1f, videoHeight = 1f, region = "region", @@ -618,14 +688,14 @@ internal class LivenessWebSocketTest { private fun createLivenessWebSocket( livenessVersion: String? = null, - sessionInformation: FaceLivenessSessionInformation? = null + clientSessionInformation: FaceLivenessSessionInformation? = null ) = LivenessWebSocket( credentialsProvider, server.url("/").toString(), "", - sessionInformation ?: defaultSessionInformation, + clientSessionInformation ?: defaultSessionInformation, livenessVersion, - onSessionInformationReceived, + onSessionResponseReceived, onErrorReceived, onComplete )