Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(liveness): Liveness web socket expiration retry #2615

Merged
merged 20 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ internal class AWSV4Signer {
timeFormatter.isLenient = false
}

// used in incorrect time flow where we send an invalid response first to get time offset
fun resetPriorSignature() {
priorSignature = ""
}

fun getSignedUri(
uri: URI,
credentials: Credentials,
Expand Down
tjleing marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ import com.amplifyframework.util.UserAgent
import java.net.URI
import java.net.URLDecoder
import java.nio.ByteBuffer
import java.text.SimpleDateFormat
import java.util.Date
import java.util.Locale
import java.util.UUID
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
Expand Down Expand Up @@ -76,6 +78,10 @@ internal class LivenessWebSocket(
private val signer = AWSV4Signer()
private var credentials: Credentials? = null

private val sdf = SimpleDateFormat("yyyymmddhhmmss", Locale.US)
tjleing marked this conversation as resolved.
Show resolved Hide resolved
internal var offset = 0L
tjleing marked this conversation as resolved.
Show resolved Hide resolved
internal var closeExpected = false
tjleing marked this conversation as resolved.
Show resolved Hide resolved

@VisibleForTesting
internal var webSocket: WebSocket? = null
internal val challengeId = UUID.randomUUID().toString()
Expand All @@ -86,13 +92,31 @@ internal class LivenessWebSocket(
private var webSocketError: PredictionsException? = null
internal var clientStoppedSession = false
val json = Json { ignoreUnknownKeys = true }
val FIVE_MINUTES = 1000 * 60 * 5
tjleing marked this conversation as resolved.
Show resolved Hide resolved
val datePattern = "EEE, d MMM yyyy HH:mm:ss z"

@VisibleForTesting
internal var webSocketListener = object : WebSocketListener() {
override fun onOpen(webSocket: WebSocket, response: Response) {
LOG.debug("WebSocket onOpen")
super.onOpen(webSocket, response)
[email protected] = webSocket

// device time may be set incorrectly; read the header to skew time and retry
val sdf = SimpleDateFormat(datePattern, Locale.US)
tjleing marked this conversation as resolved.
Show resolved Hide resolved
val date = response.header("Date")?.let { sdf.parse(it) }
val tempOffset = if (date != null) {
date.time - (Date().time + offset)
} else 0

// if offset is > 5 minutes, server will reject the request
if (kotlin.math.abs(tempOffset) < FIVE_MINUTES) {
gpanshu marked this conversation as resolved.
Show resolved Hide resolved
super.onOpen(webSocket, response)
[email protected] = webSocket
} else {
// server will close this websocket, don't report that failure back
closeExpected = true
offset = tempOffset
start()
ankpshah marked this conversation as resolved.
Show resolved Hide resolved
}
}

override fun onMessage(webSocket: WebSocket, text: String) {
Expand Down Expand Up @@ -130,7 +154,9 @@ internal class LivenessWebSocket(
override fun onClosed(webSocket: WebSocket, code: Int, reason: String) {
LOG.debug("WebSocket onClosed")
super.onClosed(webSocket, code, reason)
if (code != 1000 && !clientStoppedSession) {
tjleing marked this conversation as resolved.
Show resolved Hide resolved
if (closeExpected) {
// do nothing; we expected the server to close the connection
} else if (code != 1000 && !clientStoppedSession) {
val faceLivenessException = webSocketError ?: PredictionsException(
"An error occurred during the face liveness check.",
reason
Expand Down Expand Up @@ -173,7 +199,10 @@ internal class LivenessWebSocket(
try {
val credentials = credentialsProvider.resolve(emptyAttributes())
[email protected] = credentials
val signedUri = signer.getSignedUri(URI.create(endpoint), credentials, region, userAgent)
signer.resetPriorSignature()
val signedUri = signer.getSignedUri(
URI.create(endpoint), credentials, region, userAgent, Date().time + offset
)
if (signedUri != null) {
val signedEndpoint = URLDecoder.decode(signedUri.toString(), "UTF-8")
val signedEndpointNoSpaces = signedEndpoint.replace(" ", signer.encodedSpace)
Expand Down Expand Up @@ -215,6 +244,8 @@ internal class LivenessWebSocket(
}

private fun startWebSocket(okHttpClient: OkHttpClient, url: String) {
closeExpected = false

okHttpClient.newWebSocket(
Request.Builder().url(url).build(),
webSocketListener
Expand Down Expand Up @@ -274,14 +305,14 @@ internal class LivenessWebSocket(
videoStartTime: Long
) {
// Send initial ClientSessionInformationEvent
videoStartTimestamp = videoStartTime
videoStartTimestamp = videoStartTime + offset
initialDetectedFace = BoundingBox(
left = initialFaceRect.left / sessionInformation.videoWidth,
top = initialFaceRect.top / sessionInformation.videoHeight,
height = initialFaceRect.height() / sessionInformation.videoHeight,
width = initialFaceRect.width() / sessionInformation.videoWidth
)
faceDetectedStart = videoStartTime
faceDetectedStart = videoStartTime + offset
val clientInfoEvent =
ClientSessionInformationEvent(
challenge = ClientChallenge(
Expand Down Expand Up @@ -309,8 +340,8 @@ internal class LivenessWebSocket(
initialFaceDetectedTimestamp = faceDetectedStart
),
targetFace = TargetFace(
faceDetectedInTargetPositionStartTimestamp = faceMatchedStart,
faceDetectedInTargetPositionEndTimestamp = faceMatchedEnd,
faceDetectedInTargetPositionStartTimestamp = faceMatchedStart + offset,
faceDetectedInTargetPositionEndTimestamp = faceMatchedEnd + offset,
boundingBox = BoundingBox(
left = targetFaceRect.left / sessionInformation.videoWidth,
top = targetFaceRect.top / sessionInformation.videoHeight,
Expand Down Expand Up @@ -338,7 +369,7 @@ internal class LivenessWebSocket(
currentColor = currentColor,
previousColor = previousColor,
sequenceNumber = sequenceNumber,
currentColorStartTimestamp = colorStartTime
currentColorStartTimestamp = colorStartTime + offset
tjleing marked this conversation as resolved.
Show resolved Hide resolved
)
)
)
Expand All @@ -358,7 +389,7 @@ internal class LivenessWebSocket(
":content-type" to "application/json"
)
)
val eventDate = Date()
val eventDate = Date(Date().time + offset)
val signedPayload = signer.getSignedFrame(
region,
encodedPayload.array(),
Expand All @@ -381,12 +412,12 @@ internal class LivenessWebSocket(

fun sendVideoEvent(videoBytes: ByteArray, videoEventTime: Long) {
if (videoBytes.isNotEmpty()) {
videoEndTimestamp = videoEventTime
videoEndTimestamp = videoEventTime + offset
}
credentials?.let {
val videoBuffer = ByteBuffer.wrap(videoBytes)
val videoEvent = VideoEvent(
timestampMillis = videoEventTime,
timestampMillis = videoEventTime + offset,
videoChunk = videoBuffer
)
val videoJsonString = Json.encodeToString(videoEvent)
Expand All @@ -399,7 +430,7 @@ internal class LivenessWebSocket(
":content-type" to "application/json"
)
)
val videoEventDate = Date()
val videoEventDate = Date(Date().time + offset)
val signedVideoPayload = signer.getSignedFrame(
region,
encodedVideoPayload.array(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package com.amplifyframework.predictions.aws.models.liveness

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

/**
* Constructs a new AccessDeniedException with the specified error message.
tjleing marked this conversation as resolved.
Show resolved Hide resolved
*
* @param message Describes the error encountered.
*/
@Serializable
internal data class InvalidSignatureException(
@SerialName("Message") override val message: String
) : Exception(message)
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ internal data class LivenessResponseStream(
ServiceQuotaExceededException? = null,
@SerialName("ServiceUnavailableException") val serviceUnavailableException: ServiceUnavailableException? = null,
@SerialName("SessionNotFoundException") val sessionNotFoundException: SessionNotFoundException? = null,
@SerialName("AccessDeniedException") val accessDeniedException: AccessDeniedException? = null
@SerialName("AccessDeniedException") val accessDeniedException: AccessDeniedException? = null,
@SerialName("InvalidSignatureException") val invalidSignatureException: InvalidSignatureException? = null
)
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,18 @@ import com.amplifyframework.predictions.aws.models.liveness.ServerSessionInforma
import com.amplifyframework.predictions.aws.models.liveness.SessionInformation
import com.amplifyframework.predictions.aws.models.liveness.ValidationException
import com.amplifyframework.predictions.models.FaceLivenessSessionInformation
import io.mockk.every
import io.mockk.mockk
import io.mockk.mockkConstructor
import io.mockk.verify
import java.net.URL
import java.text.SimpleDateFormat
import java.util.Date
import java.util.Locale
import java.util.TimeZone
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
import kotlin.math.abs
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.resetMain
Expand All @@ -48,6 +55,7 @@ import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import mockwebserver3.MockResponse
import mockwebserver3.MockWebServer
import okhttp3.Headers
import okhttp3.Protocol
import okhttp3.Request
import okhttp3.Response
Expand Down Expand Up @@ -317,6 +325,57 @@ internal class LivenessWebSocketTest {
assertEquals(livenessWebSocket.getUserAgent(), "$baseline $additional")
}

@Test
fun `web socket detects clock skew from server response`() {
val livenessWebSocket = createLivenessWebSocket()
mockkConstructor(WebSocket::class)
val socket: WebSocket = mockk()
livenessWebSocket.webSocket = socket
val sdf = SimpleDateFormat(livenessWebSocket.datePattern, Locale.US)

// server responds saying time is actually 1 hour in the future
val oneHour = 1000 * 3600
val futureDate = sdf.format(Date(Date().time + oneHour))
val response: Response = mockk()
every { response.headers }.returns(Headers.headersOf("Date", futureDate))
every { response.header("Date") }.returns(futureDate)
livenessWebSocket.webSocketListener.onOpen(socket, response)

// now we should restart the websocket with an adjusted time
val openLatch = CountDownLatch(1)
val latchingListener = LatchingWebSocketResponseListener(
livenessWebSocket.webSocketListener,
openLatch = openLatch
)
livenessWebSocket.webSocketListener = latchingListener

server.enqueue(MockResponse().withWebSocketUpgrade(ServerWebSocketListener()))
server.start()
livenessWebSocket.start()

openLatch.await(3, TimeUnit.SECONDS)

assertTrue(livenessWebSocket.webSocket != null)
val originalRequest = livenessWebSocket.webSocket!!.request()

// make sure that followup request sends offset date
val sdfGMT = SimpleDateFormat("yyyyMMdd'T'HHmmss'Z'", Locale.US)
sdfGMT.timeZone = TimeZone.getTimeZone("GMT")
val sentDate = originalRequest.url.queryParameter("X-Amz-Date") ?.let { sdfGMT.parse(it) }
val diff = abs(Date().time - sentDate?.time!!)
assert(oneHour - 10000 < diff && diff < oneHour + 10000)
tjleing marked this conversation as resolved.
Show resolved Hide resolved

// also make sure that followup request is valid
assertEquals("AWS4-HMAC-SHA256", originalRequest.url.queryParameter("X-Amz-Algorithm"))
assertTrue(
originalRequest.url.queryParameter("X-Amz-Credential")!!.endsWith("//rekognition/aws4_request")
)
assertEquals("299", originalRequest.url.queryParameter("X-Amz-Expires"))
assertEquals("host", originalRequest.url.queryParameter("X-Amz-SignedHeaders"))
assertEquals("AWS4-HMAC-SHA256", originalRequest.url.queryParameter("X-Amz-Algorithm"))
tjleing marked this conversation as resolved.
Show resolved Hide resolved
assertNotNull("x-amz-user-agent")
tjleing marked this conversation as resolved.
Show resolved Hide resolved
}

@Test
@Ignore("Need to work on parsing the onMessage byteString from ServerWebSocketListener")
fun `sendInitialFaceDetectedEvent test`() {
Expand Down
4 changes: 2 additions & 2 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ androidx-test-junit = "1.1.2"
androidx-test-orchestrator = "1.3.0"
androidx-test-runner = "1.3.0"
androidx-workmanager = "2.7.1"
aws-kotlin = "0.29.1-beta" # ensure proper aws-smithy version also set
gpanshu marked this conversation as resolved.
Show resolved Hide resolved
aws-kotlin = "0.33.0-beta" # ensure proper aws-smithy version also set
aws-sdk = "2.62.2"
aws-smithy = "0.25.0" # ensure proper aws-kotlin version also set
aws-smithy = "0.28.0" # ensure proper aws-kotlin version also set
coroutines = "1.6.3"
desugar = "1.2.0"
espresso = "3.3.0"
Expand Down
Loading