Skip to content

Commit

Permalink
Remove usages of !! and properly close the websocket when an unsuppor…
Browse files Browse the repository at this point in the history
…ted challenge type is found
  • Loading branch information
vincetran committed Jul 1, 2024
1 parent b51112b commit d776d4b
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 40 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package com.amplifyframework.predictions.aws.exceptions

import com.amplifyframework.annotations.InternalAmplifyApi
import com.amplifyframework.predictions.PredictionsException

@InternalAmplifyApi
class FaceLivenessUnsupportedChallengeTypeException internal constructor(
message: String = "Received an unsupported ChallengeType from the backend.",
cause: Throwable? = null,
recoverySuggestion: String = "Verify that the Challenges configured in your backend are supported by the " +
"frontend code (e.g. Amplify UI)"
) : PredictionsException(message, cause, recoverySuggestion)
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import com.amplifyframework.predictions.PredictionsException
import com.amplifyframework.predictions.aws.BuildConfig
import com.amplifyframework.predictions.aws.exceptions.AccessDeniedException
import com.amplifyframework.predictions.aws.exceptions.FaceLivenessSessionNotFoundException
import com.amplifyframework.predictions.aws.exceptions.FaceLivenessUnsupportedChallengeTypeException
import com.amplifyframework.predictions.aws.models.liveness.BoundingBox
import com.amplifyframework.predictions.aws.models.liveness.ClientChallenge
import com.amplifyframework.predictions.aws.models.liveness.ClientSessionInformationEvent
Expand Down Expand Up @@ -78,10 +79,14 @@ internal class LivenessWebSocket(
val region: String,
val clientSessionInformation: FaceLivenessSessionInformation,
val livenessVersion: String?,
val onSessionInformationReceived: Consumer<SessionInformation>,
val onSessionResponseReceived: Consumer<SessionResponse>,
val onErrorReceived: Consumer<PredictionsException>,
val onComplete: Action
) {
internal data class SessionResponse(
val faceLivenessSession: SessionInformation,
val livenessChallengeType: FaceLivenessChallengeType
)

private val signer = AWSV4Signer()
private var credentials: Credentials? = null
Expand Down Expand Up @@ -165,17 +170,17 @@ internal class LivenessWebSocket(

// If challengeType hasn't been initialized by this point it's because server sent an
// unsupported challenge type so return an error to the client.
if (challengeType == null) {
onErrorReceived.accept(
PredictionsException(
"Received an unsupported ChallengeType from the backend",
"Verify that the Challenges configured in your backend are supported by the " +
"frontend code (e.g. Amplify UI)"
)
)
val resolvedChallengeType = challengeType
if (resolvedChallengeType == null) {
webSocketError = FaceLivenessUnsupportedChallengeTypeException()
destroy(UNSUPPORTED_CHALLENGE_CLOSURE_STATUS_CODE)
} else {
onSessionInformationReceived.accept(
response.serverSessionInformationEvent.sessionInformation
// send non-null challenge Type
onSessionResponseReceived.accept(
SessionResponse(
response.serverSessionInformationEvent.sessionInformation,
resolvedChallengeType
)
)
}
} else if (response.disconnectionEvent != null) {
Expand Down Expand Up @@ -394,7 +399,8 @@ internal class LivenessWebSocket(
)
faceDetectedStart = adjustedDate(videoStartTime)

if (challengeType == null) {
val resolvedChallengeType = challengeType
if (resolvedChallengeType == null) {
onErrorReceived.accept(
PredictionsException(
"Failed to send an initial face detected event",
Expand All @@ -405,7 +411,7 @@ internal class LivenessWebSocket(
val clientInfoEvent =
ClientSessionInformationEvent(
challenge = buildClientChallenge(
challengeType = challengeType!!,
challengeType = resolvedChallengeType,
challengeId = challengeId,
initialFace = InitialFace(
boundingBox = initialDetectedFace!!,
Expand All @@ -419,7 +425,8 @@ internal class LivenessWebSocket(
}

fun sendFinalEvent(targetFaceRect: RectF, faceMatchedStart: Long, faceMatchedEnd: Long) {
if (challengeType == null) {
val resolvedChallengeType = challengeType
if (resolvedChallengeType == null) {
onErrorReceived.accept(
PredictionsException(
"Failed to send an initial face detected event",
Expand All @@ -429,7 +436,7 @@ internal class LivenessWebSocket(
} else {
val finalClientInfoEvent = ClientSessionInformationEvent(
challenge = buildClientChallenge(
challengeType = challengeType!!,
challengeType = resolvedChallengeType,
challengeId = challengeId,
videoEndTimestamp = videoEndTimestamp,
initialFace = InitialFace(
Expand Down Expand Up @@ -607,6 +614,8 @@ internal class LivenessWebSocket(

companion object {
private const val NORMAL_SOCKET_CLOSURE_STATUS_CODE = 1000
// This is the same as the client-provided 'runtime error' status code
private const val UNSUPPORTED_CHALLENGE_CLOSURE_STATUS_CODE = 4005
private const val FOUR_MINUTES = 1000 * 60 * 4
@VisibleForTesting val datePattern = "EEE, d MMM yyyy HH:mm:ss z"
private val LOG = Amplify.Logging.logger(CategoryType.PREDICTIONS, "amplify:aws-predictions")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ import com.amplifyframework.predictions.aws.models.liveness.FreshnessColor
import com.amplifyframework.predictions.aws.models.liveness.OvalParameters
import com.amplifyframework.predictions.aws.models.liveness.SessionInformation
import com.amplifyframework.predictions.models.ChallengeResponseEvent
import com.amplifyframework.predictions.models.FaceLivenessChallengeType
import com.amplifyframework.predictions.models.FaceLivenessSession
import com.amplifyframework.predictions.models.FaceLivenessSessionChallenge
import com.amplifyframework.predictions.models.FaceLivenessSessionInformation
Expand All @@ -58,9 +57,9 @@ internal class RunFaceLivenessSession(
region = clientSessionInformation.region,
clientSessionInformation = clientSessionInformation,
livenessVersion = livenessVersion,
onSessionInformationReceived = { serverSessionInformation ->
val challenges = processSessionInformation(serverSessionInformation)
val challengeType = getChallengeType()
onSessionResponseReceived = { serverSessionResponse ->
val challenges = processSessionInformation(serverSessionResponse.faceLivenessSession)
val challengeType = serverSessionResponse.livenessChallengeType
val faceLivenessSession = FaceLivenessSession(
challengeId = getChallengeId(),
challengeType = challengeType,
Expand Down Expand Up @@ -128,8 +127,6 @@ internal class RunFaceLivenessSession(

private fun getChallengeId(): String = livenessWebSocket.challengeId

private fun getChallengeType(): FaceLivenessChallengeType = livenessWebSocket.challengeType!!

private fun getFaceTargetChallenge(
ovalParameters: OvalParameters,
challengeConfig: ChallengeConfig
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import com.amplifyframework.core.Action
import com.amplifyframework.core.BuildConfig
import com.amplifyframework.core.Consumer
import com.amplifyframework.predictions.PredictionsException
import com.amplifyframework.predictions.aws.exceptions.FaceLivenessUnsupportedChallengeTypeException
import com.amplifyframework.predictions.aws.models.liveness.ChallengeConfig
import com.amplifyframework.predictions.aws.models.liveness.ChallengeEvent
import com.amplifyframework.predictions.aws.models.liveness.ColorSequence
Expand Down Expand Up @@ -86,7 +87,7 @@ internal class LivenessWebSocketTest {
private lateinit var server: MockWebServer

private val onComplete = mockk<Action>(relaxed = true)
private val onSessionInformationReceived = mockk<Consumer<SessionInformation>>(relaxed = true)
private val onSessionResponseReceived = mockk<Consumer<LivenessWebSocket.SessionResponse>>(relaxed = true)
private val onErrorReceived = mockk<Consumer<PredictionsException>>(relaxed = true)
private val credentialsProvider = object : CredentialsProvider {
override suspend fun resolve(attributes: Attributes): Credentials {
Expand All @@ -100,7 +101,7 @@ internal class LivenessWebSocketTest {
}
}

private val defaultSessionInformation = createSessionInformation(
private val defaultSessionInformation = createClientSessionInformation(
listOf(Challenge.FaceMovementChallenge("1.0.0"))
)

Expand Down Expand Up @@ -213,6 +214,67 @@ internal class LivenessWebSocketTest {
assertEquals("AWS4-HMAC-SHA256", originalRequest.url.queryParameter("X-Amz-Algorithm"))
}

@Test
fun `test unsupported challengetype`() {
val clientSessionInfo = createClientSessionInformation(
listOf(Challenge.FaceMovementAndLightChallenge("2.0.0"))
)
val livenessWebSocket = createLivenessWebSocket(clientSessionInformation = clientSessionInfo)
val unknownEvent = "{\"Type\":\"NewChallengeType\",\"Version\":\"1.0.0\"}"

val challengeEventHeaders = mapOf(
":event-type" to "ChallengeEvent",
":content-type" to "application/json",
":message-type" to "event"
)

val encodedChallengeTypeByteString =
LivenessEventStream.encode(unknownEvent.toByteArray(), challengeEventHeaders).array().toByteString()

livenessWebSocket.webSocketListener.onMessage(mockk(), encodedChallengeTypeByteString)

assertEquals(null, livenessWebSocket.challengeType)

val event = ServerSessionInformationEvent(
sessionInformation = SessionInformation(
challenge = ServerChallenge(
faceMovementAndLightChallenge = FaceMovementAndLightServerChallenge(
ovalParameters = OvalParameters(1.0f, 2.0f, .5f, .7f),
lightChallengeType = LightChallengeType.SEQUENTIAL,
challengeConfig = ChallengeConfig(
1.0f,
1.1f,
1.2f,
1.3f,
1.4f,
1.5f,
1.6f,
1.7f,
1.8f,
1.9f,
10
),
colorSequences = listOf(
ColorSequence(FreshnessColor(listOf(0, 1, 2)), 4.0f, 5.0f)
)
)
)
)
)

val headers = mapOf(
":event-type" to "ServerSessionInformationEvent",
":content-type" to "application/json",
":message-type" to "event"
)

val data = json.encodeToString(event)
val encodedByteString = LivenessEventStream.encode(data.toByteArray(), headers).array().toByteString()

livenessWebSocket.webSocketListener.onMessage(mockk(), encodedByteString)
assertEquals(FaceLivenessUnsupportedChallengeTypeException(), livenessWebSocket.webSocketError)
}

@Test
fun `ensure challengetype is properly set when using the deprecated sessioninformation`() {
val sessionInfo = FaceLivenessSessionInformation(
Expand All @@ -221,7 +283,7 @@ internal class LivenessWebSocketTest {
"FaceMovementAndLightChallenge_1.0.0",
"3"
)
val livenessWebSocket = createLivenessWebSocket(sessionInformation = sessionInfo)
val livenessWebSocket = createLivenessWebSocket(clientSessionInformation = sessionInfo)
val event = ChallengeEvent(
challengeType = FaceLivenessChallengeType.FaceMovementAndLightChallenge,
version = "1.0.0"
Expand All @@ -243,8 +305,8 @@ internal class LivenessWebSocketTest {

@Test
fun `server facemovementandlight challenge event tracked`() {
val sessionInfo = createSessionInformation(listOf(Challenge.FaceMovementAndLightChallenge("2.0.0")))
val livenessWebSocket = createLivenessWebSocket(sessionInformation = sessionInfo)
val sessionInfo = createClientSessionInformation(listOf(Challenge.FaceMovementAndLightChallenge("2.0.0")))
val livenessWebSocket = createLivenessWebSocket(clientSessionInformation = sessionInfo)
val event = ChallengeEvent(
challengeType = FaceLivenessChallengeType.FaceMovementAndLightChallenge,
version = "2.0.0"
Expand All @@ -266,8 +328,8 @@ internal class LivenessWebSocketTest {

@Test
fun `server facemovement challenge event tracked`() {
val sessionInfo = createSessionInformation(listOf(Challenge.FaceMovementChallenge("1.0.0")))
val livenessWebSocket = createLivenessWebSocket(sessionInformation = sessionInfo)
val sessionInfo = createClientSessionInformation(listOf(Challenge.FaceMovementChallenge("1.0.0")))
val livenessWebSocket = createLivenessWebSocket(clientSessionInformation = sessionInfo)

val event = ChallengeEvent(
challengeType = FaceLivenessChallengeType.FaceMovementChallenge,
Expand All @@ -290,8 +352,8 @@ internal class LivenessWebSocketTest {

@Test
fun `server facemovementandlight session event tracked`() {
val sessionInfo = createSessionInformation(listOf(Challenge.FaceMovementAndLightChallenge("2.0.0")))
val livenessWebSocket = createLivenessWebSocket(sessionInformation = sessionInfo)
val sessionInfo = createClientSessionInformation(listOf(Challenge.FaceMovementAndLightChallenge("2.0.0")))
val livenessWebSocket = createLivenessWebSocket(clientSessionInformation = sessionInfo)

val challengeEvent = ChallengeEvent(
challengeType = FaceLivenessChallengeType.FaceMovementAndLightChallenge,
Expand Down Expand Up @@ -335,6 +397,10 @@ internal class LivenessWebSocketTest {
)
)
)
val sessionResponse = LivenessWebSocket.SessionResponse(
event.sessionInformation,
FaceLivenessChallengeType.FaceMovementAndLightChallenge
)
val headers = mapOf(
":event-type" to "ServerSessionInformationEvent",
":content-type" to "application/json",
Expand All @@ -346,13 +412,13 @@ internal class LivenessWebSocketTest {

livenessWebSocket.webSocketListener.onMessage(mockk(), encodedByteString)

verify { onSessionInformationReceived.accept(event.sessionInformation) }
verify { onSessionResponseReceived.accept(sessionResponse) }
}

@Test
fun `server facemovement session event tracked`() {
val sessionInfo = createSessionInformation(listOf(Challenge.FaceMovementChallenge("1.0.0")))
val livenessWebSocket = createLivenessWebSocket(sessionInformation = sessionInfo)
val sessionInfo = createClientSessionInformation(listOf(Challenge.FaceMovementChallenge("1.0.0")))
val livenessWebSocket = createLivenessWebSocket(clientSessionInformation = sessionInfo)

val challengeEvent = ChallengeEvent(
challengeType = FaceLivenessChallengeType.FaceMovementChallenge,
Expand Down Expand Up @@ -392,6 +458,10 @@ internal class LivenessWebSocketTest {
)
)
)
val sessionResponse = LivenessWebSocket.SessionResponse(
event.sessionInformation,
FaceLivenessChallengeType.FaceMovementChallenge
)
val sessionHeaders = mapOf(
":event-type" to "ServerSessionInformationEvent",
":content-type" to "application/json",
Expand All @@ -403,7 +473,7 @@ internal class LivenessWebSocketTest {

livenessWebSocket.webSocketListener.onMessage(mockk(), encodedByteString)

verify { onSessionInformationReceived.accept(event.sessionInformation) }
verify { onSessionResponseReceived.accept(sessionResponse) }
}

@Test
Expand All @@ -423,7 +493,7 @@ internal class LivenessWebSocketTest {

livenessWebSocket.webSocketListener.onMessage(mockk(), encodedByteString)

verify(exactly = 0) { onSessionInformationReceived.accept(any()) }
verify(exactly = 0) { onSessionResponseReceived.accept(any()) }
verify(exactly = 0) { onErrorReceived.accept(any()) }
verify(exactly = 0) { webSocket.close(any(), any()) }
}
Expand All @@ -445,7 +515,7 @@ internal class LivenessWebSocketTest {

livenessWebSocket.webSocketListener.onMessage(mockk(), encodedByteString)

verify(exactly = 0) { onSessionInformationReceived.accept(any()) }
verify(exactly = 0) { onSessionResponseReceived.accept(any()) }
verify(exactly = 0) { onErrorReceived.accept(any()) }
verify(exactly = 1) { webSocket.close(any(), any()) }
}
Expand Down Expand Up @@ -607,7 +677,7 @@ internal class LivenessWebSocketTest {
fun `sendVideoEvent test`() {
}

private fun createSessionInformation(challengeVersions: List<Challenge>) = FaceLivenessSessionInformation(
private fun createClientSessionInformation(challengeVersions: List<Challenge>) = FaceLivenessSessionInformation(
videoWidth = 1f,
videoHeight = 1f,
region = "region",
Expand All @@ -618,14 +688,14 @@ internal class LivenessWebSocketTest {

private fun createLivenessWebSocket(
livenessVersion: String? = null,
sessionInformation: FaceLivenessSessionInformation? = null
clientSessionInformation: FaceLivenessSessionInformation? = null
) = LivenessWebSocket(
credentialsProvider,
server.url("/").toString(),
"",
sessionInformation ?: defaultSessionInformation,
clientSessionInformation ?: defaultSessionInformation,
livenessVersion,
onSessionInformationReceived,
onSessionResponseReceived,
onErrorReceived,
onComplete
)
Expand Down

0 comments on commit d776d4b

Please sign in to comment.