diff --git a/aws-predictions/src/main/java/com/amplifyframework/predictions/aws/http/LivenessWebSocket.kt b/aws-predictions/src/main/java/com/amplifyframework/predictions/aws/http/LivenessWebSocket.kt index bb9323477..4eedf4ebf 100644 --- a/aws-predictions/src/main/java/com/amplifyframework/predictions/aws/http/LivenessWebSocket.kt +++ b/aws-predictions/src/main/java/com/amplifyframework/predictions/aws/http/LivenessWebSocket.kt @@ -97,7 +97,7 @@ internal class LivenessWebSocket( @VisibleForTesting internal var webSocket: WebSocket? = null internal val challengeId = UUID.randomUUID().toString() - lateinit var challengeType: FaceLivenessChallengeType + var challengeType: FaceLivenessChallengeType? = null private var initialDetectedFace: BoundingBox? = null private var faceDetectedStart = 0L private var videoStartTimestamp = 0L @@ -139,16 +139,6 @@ internal class LivenessWebSocket( } } - init { - // If the client session is requesting a 1.0.0 FaceMovementAndLight challenge, the backend won't return a - // ChallengeEvent so we have to set the challengeType manually - clientSessionInformation.challengeVersions.forEach { - if (it.compareType(Challenge.FaceMovementAndLightChallenge("1.0.0"))) { - challengeType = FaceLivenessChallengeType.FaceMovementAndLightChallenge - } - } - } - override fun onMessage(webSocket: WebSocket, text: String) { LOG.debug("WebSocket onMessage text") super.onMessage(webSocket, text) @@ -162,9 +152,20 @@ internal class LivenessWebSocket( if (response.challengeEvent != null) { challengeType = response.challengeEvent.challengeType } else if (response.serverSessionInformationEvent != null) { + + val clientRequestedOldLightChallenge = clientSessionInformation.challengeVersions + .any { it == Challenge.FaceMovementAndLightChallenge("1.0.0") } + + if (challengeType == null && clientRequestedOldLightChallenge) { + // For the 1.0.0 version of FaceMovementAndLight challenge, backend doesn't send a + // ChallengeEvent so we need to manually check and set it if that specific challenge + // was requested. + challengeType = FaceLivenessChallengeType.FaceMovementAndLightChallenge + } + // If challengeType hasn't been initialized by this point it's because server sent an // unsupported challenge type so return an error to the client. - if (!this@LivenessWebSocket::challengeType.isInitialized) { + if (challengeType == null) { onErrorReceived.accept( PredictionsException( "Received an unsupported ChallengeType from the backend", @@ -392,44 +393,63 @@ internal class LivenessWebSocket( width = initialFaceRect.width() / clientSessionInformation.videoWidth ) faceDetectedStart = adjustedDate(videoStartTime) - val clientInfoEvent = - ClientSessionInformationEvent( + + if (challengeType == null) { + onErrorReceived.accept( + PredictionsException( + "Failed to send an initial face detected event", + AmplifyException.TODO_RECOVERY_SUGGESTION + ) + ) + } else { + val clientInfoEvent = + ClientSessionInformationEvent( + challenge = buildClientChallenge( + challengeType = challengeType!!, + challengeId = challengeId, + initialFace = InitialFace( + boundingBox = initialDetectedFace!!, + initialFaceDetectedTimestamp = faceDetectedStart + ), + videoStartTimestamp = videoStartTimestamp + ) + ) + sendClientInfoEvent(clientInfoEvent) + } + } + + fun sendFinalEvent(targetFaceRect: RectF, faceMatchedStart: Long, faceMatchedEnd: Long) { + if (challengeType == null) { + onErrorReceived.accept( + PredictionsException( + "Failed to send an initial face detected event", + AmplifyException.TODO_RECOVERY_SUGGESTION + ) + ) + } else { + val finalClientInfoEvent = ClientSessionInformationEvent( challenge = buildClientChallenge( - challengeType = challengeType, + challengeType = challengeType!!, challengeId = challengeId, + videoEndTimestamp = videoEndTimestamp, initialFace = InitialFace( boundingBox = initialDetectedFace!!, initialFaceDetectedTimestamp = faceDetectedStart ), - videoStartTimestamp = videoStartTimestamp - ) - ) - sendClientInfoEvent(clientInfoEvent) - } - - fun sendFinalEvent(targetFaceRect: RectF, faceMatchedStart: Long, faceMatchedEnd: Long) { - val finalClientInfoEvent = ClientSessionInformationEvent( - challenge = buildClientChallenge( - challengeType = challengeType, - challengeId = challengeId, - videoEndTimestamp = videoEndTimestamp, - initialFace = InitialFace( - boundingBox = initialDetectedFace!!, - initialFaceDetectedTimestamp = faceDetectedStart - ), - targetFace = TargetFace( - faceDetectedInTargetPositionStartTimestamp = adjustedDate(faceMatchedStart), - faceDetectedInTargetPositionEndTimestamp = adjustedDate(faceMatchedEnd), - boundingBox = BoundingBox( - left = targetFaceRect.left / clientSessionInformation.videoWidth, - top = targetFaceRect.top / clientSessionInformation.videoHeight, - height = targetFaceRect.height() / clientSessionInformation.videoHeight, - width = targetFaceRect.width() / clientSessionInformation.videoWidth + targetFace = TargetFace( + faceDetectedInTargetPositionStartTimestamp = adjustedDate(faceMatchedStart), + faceDetectedInTargetPositionEndTimestamp = adjustedDate(faceMatchedEnd), + boundingBox = BoundingBox( + left = targetFaceRect.left / clientSessionInformation.videoWidth, + top = targetFaceRect.top / clientSessionInformation.videoHeight, + height = targetFaceRect.height() / clientSessionInformation.videoHeight, + width = targetFaceRect.width() / clientSessionInformation.videoWidth + ) ) ) ) - ) - sendClientInfoEvent(finalClientInfoEvent) + sendClientInfoEvent(finalClientInfoEvent) + } } fun sendColorDisplayedEvent( diff --git a/aws-predictions/src/main/java/com/amplifyframework/predictions/aws/service/RunFaceLivenessSession.kt b/aws-predictions/src/main/java/com/amplifyframework/predictions/aws/service/RunFaceLivenessSession.kt index beb123edf..a5f3df265 100644 --- a/aws-predictions/src/main/java/com/amplifyframework/predictions/aws/service/RunFaceLivenessSession.kt +++ b/aws-predictions/src/main/java/com/amplifyframework/predictions/aws/service/RunFaceLivenessSession.kt @@ -128,7 +128,7 @@ internal class RunFaceLivenessSession( private fun getChallengeId(): String = livenessWebSocket.challengeId - private fun getChallengeType(): FaceLivenessChallengeType = livenessWebSocket.challengeType + private fun getChallengeType(): FaceLivenessChallengeType = livenessWebSocket.challengeType!! private fun getFaceTargetChallenge( ovalParameters: OvalParameters, diff --git a/aws-predictions/src/test/java/com/amplifyframework/predictions/aws/http/LivenessWebSocketTest.kt b/aws-predictions/src/test/java/com/amplifyframework/predictions/aws/http/LivenessWebSocketTest.kt index b41ae8aee..8d9ba7f92 100644 --- a/aws-predictions/src/test/java/com/amplifyframework/predictions/aws/http/LivenessWebSocketTest.kt +++ b/aws-predictions/src/test/java/com/amplifyframework/predictions/aws/http/LivenessWebSocketTest.kt @@ -222,8 +222,6 @@ internal class LivenessWebSocketTest { "3" ) val livenessWebSocket = createLivenessWebSocket(sessionInformation = sessionInfo) - assertEquals(FaceLivenessChallengeType.FaceMovementAndLightChallenge, livenessWebSocket.challengeType) - val event = ChallengeEvent( challengeType = FaceLivenessChallengeType.FaceMovementAndLightChallenge, version = "1.0.0" diff --git a/core/src/main/java/com/amplifyframework/predictions/models/FaceLivenessSessionInformation.kt b/core/src/main/java/com/amplifyframework/predictions/models/FaceLivenessSessionInformation.kt index ad408bfe3..8bc28bc06 100644 --- a/core/src/main/java/com/amplifyframework/predictions/models/FaceLivenessSessionInformation.kt +++ b/core/src/main/java/com/amplifyframework/predictions/models/FaceLivenessSessionInformation.kt @@ -58,17 +58,14 @@ class FaceLivenessSessionInformation { } @InternalAmplifyApi -sealed class Challenge private constructor(val name: String, val version: String) { +sealed class Challenge private constructor(val name: String, open val version: String) { @InternalAmplifyApi - class FaceMovementChallenge(version: String) : Challenge("FaceMovementChallenge", version) + data class FaceMovementChallenge(override val version: String) : Challenge("FaceMovementChallenge", version) @InternalAmplifyApi - class FaceMovementAndLightChallenge(version: String) : Challenge("FaceMovementAndLightChallenge", version) - - fun compareType(challenge: Challenge): Boolean { - return this.name == challenge.name && this.version == challenge.version - } + data class FaceMovementAndLightChallenge(override val version: String) : + Challenge("FaceMovementAndLightChallenge", version) fun toQueryParamString(): String = "${name}_$version" }