diff --git a/aws-predictions/src/main/java/com/amplifyframework/predictions/aws/http/LivenessWebSocket.kt b/aws-predictions/src/main/java/com/amplifyframework/predictions/aws/http/LivenessWebSocket.kt index c81d6630b..1ebe93e24 100644 --- a/aws-predictions/src/main/java/com/amplifyframework/predictions/aws/http/LivenessWebSocket.kt +++ b/aws-predictions/src/main/java/com/amplifyframework/predictions/aws/http/LivenessWebSocket.kt @@ -52,7 +52,11 @@ import java.util.Date import java.util.Locale import java.util.UUID import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.CoroutineStart import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.consumeEach import kotlinx.coroutines.launch import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json @@ -98,6 +102,15 @@ internal class LivenessWebSocket( internal var clientStoppedSession = false val json = Json { ignoreUnknownKeys = true } + // Sending events to the websocket requires processing synchronously because we rely on proper ordered + // prior signatures. When sending events, we send each of these events to an async queue to process 1 at a time. + private val sendEventScope = CoroutineScope(Job() + Dispatchers.IO) + private val sendEventQueueChannel = Channel(capacity = Channel.UNLIMITED).apply { + sendEventScope.launch { + consumeEach { it.join() } + } + } + @VisibleForTesting internal var webSocketListener = object : WebSocketListener() { override fun onOpen(webSocket: WebSocket, response: Response) { @@ -409,76 +422,87 @@ internal class LivenessWebSocket( } private fun sendClientInfoEvent(clientInfoEvent: ClientSessionInformationEvent) { - credentials?.let { - val jsonString = Json.encodeToString(clientInfoEvent) - val jsonPayload = jsonString.encodeUtf8().toByteArray() - val encodedPayload = LivenessEventStream.encode( - jsonPayload, - mapOf( - ":event-type" to "ClientSessionInformationEvent", - ":message-type" to "event", - ":content-type" to "application/json" - ) - ) - val eventDate = Date(adjustedDate()) - val signedPayload = signer.getSignedFrame( - region, - encodedPayload.array(), - it.secretAccessKey, - Pair(":date", eventDate) - ) - val signedPayloadBytes = signedPayload.chunked(2).map { hexChar -> hexChar.toInt(16).toByte() } - .toByteArray() - val encodedRequest = LivenessEventStream.encode( - encodedPayload.array(), - mapOf( - ":date" to eventDate, - ":chunk-signature" to signedPayloadBytes - ) - ) + // Add event to send queue to ensure proper ordering of signatures + sendEventQueueChannel.trySend( + sendEventScope.launch(start = CoroutineStart.LAZY) { + credentials?.let { + val jsonString = Json.encodeToString(clientInfoEvent) + val jsonPayload = jsonString.encodeUtf8().toByteArray() + val encodedPayload = LivenessEventStream.encode( + jsonPayload, + mapOf( + ":event-type" to "ClientSessionInformationEvent", + ":message-type" to "event", + ":content-type" to "application/json" + ) + ) + val eventDate = Date(adjustedDate()) + val signedPayload = signer.getSignedFrame( + region, + encodedPayload.array(), + it.secretAccessKey, + Pair(":date", eventDate) + ) + val signedPayloadBytes = signedPayload.chunked(2).map { hexChar -> + hexChar.toInt(16).toByte() + }.toByteArray() + val encodedRequest = LivenessEventStream.encode( + encodedPayload.array(), + mapOf( + ":date" to eventDate, + ":chunk-signature" to signedPayloadBytes + ) + ) - webSocket?.send(ByteString.of(*encodedRequest.array())) - } + webSocket?.send(ByteString.of(*encodedRequest.array())) + } + } + ) } fun sendVideoEvent(videoBytes: ByteArray, videoEventTime: Long) { - if (videoBytes.isNotEmpty()) { - videoEndTimestamp = adjustedDate(videoEventTime) - } - credentials?.let { - val videoBuffer = ByteBuffer.wrap(videoBytes) - val videoEvent = VideoEvent( - timestampMillis = adjustedDate(videoEventTime), - videoChunk = videoBuffer - ) - val videoJsonString = Json.encodeToString(videoEvent) - val videoJsonPayload = videoJsonString.encodeUtf8().toByteArray() - val encodedVideoPayload = LivenessEventStream.encode( - videoJsonPayload, - mapOf( - ":event-type" to "VideoEvent", - ":message-type" to "event", - ":content-type" to "application/json" - ) - ) - val videoEventDate = Date(adjustedDate()) - val signedVideoPayload = signer.getSignedFrame( - region, - encodedVideoPayload.array(), - it.secretAccessKey, - Pair(":date", videoEventDate) - ) - val signedVideoPayloadBytes = signedVideoPayload.chunked(2) - .map { hexChar -> hexChar.toInt(16).toByte() }.toByteArray() - val encodedVideoRequest = LivenessEventStream.encode( - encodedVideoPayload.array(), - mapOf( - ":date" to videoEventDate, - ":chunk-signature" to signedVideoPayloadBytes - ) - ) - webSocket?.send(ByteString.of(*encodedVideoRequest.array())) - } + // Add event to send queue to ensure proper ordering of signatures + sendEventQueueChannel.trySend( + sendEventScope.launch(start = CoroutineStart.LAZY) { + if (videoBytes.isNotEmpty()) { + videoEndTimestamp = adjustedDate(videoEventTime) + } + credentials?.let { + val videoBuffer = ByteBuffer.wrap(videoBytes) + val videoEvent = VideoEvent( + timestampMillis = adjustedDate(videoEventTime), + videoChunk = videoBuffer + ) + val videoJsonString = Json.encodeToString(videoEvent) + val videoJsonPayload = videoJsonString.encodeUtf8().toByteArray() + val encodedVideoPayload = LivenessEventStream.encode( + videoJsonPayload, + mapOf( + ":event-type" to "VideoEvent", + ":message-type" to "event", + ":content-type" to "application/json" + ) + ) + val videoEventDate = Date(adjustedDate()) + val signedVideoPayload = signer.getSignedFrame( + region, + encodedVideoPayload.array(), + it.secretAccessKey, + Pair(":date", videoEventDate) + ) + val signedVideoPayloadBytes = signedVideoPayload.chunked(2) + .map { hexChar -> hexChar.toInt(16).toByte() }.toByteArray() + val encodedVideoRequest = LivenessEventStream.encode( + encodedVideoPayload.array(), + mapOf( + ":date" to videoEventDate, + ":chunk-signature" to signedVideoPayloadBytes + ) + ) + webSocket?.send(ByteString.of(*encodedVideoRequest.array())) + } + } + ) } fun destroy(reasonCode: Int = NORMAL_SOCKET_CLOSURE_STATUS_CODE) {