From 5d54b5d156cec85d529d131e4ca600f79a0717dc Mon Sep 17 00:00:00 2001 From: Abhash Kumar Singh Date: Mon, 19 Aug 2024 11:28:31 -0700 Subject: [PATCH] feat(predictions): add web socket retry for clock skew (#3816) * feat(predictions): add web socket retry for clock skew * address review comments --- .../AWSTranscribeStreamingAdapter.swift | 10 +-- .../Service/FaceLivenessSession.swift | 63 +++++++++++++----- .../Liveness/Service/WebSocketSession.swift | 66 +++++++++++++++---- 3 files changed, 105 insertions(+), 34 deletions(-) diff --git a/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Dependency/AWSTranscribeStreamingAdapter.swift b/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Dependency/AWSTranscribeStreamingAdapter.swift index ec0cbeb219..5c70b79948 100644 --- a/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Dependency/AWSTranscribeStreamingAdapter.swift +++ b/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Dependency/AWSTranscribeStreamingAdapter.swift @@ -131,17 +131,17 @@ class AWSTranscribeStreamingAdapter: AWSTranscribeStreamingBehavior { continuation.yield(transcribedPayload) let isPartial = transcribedPayload.transcript?.results?.map(\.isPartial) ?? [] let shouldContinue = isPartial.allSatisfy { $0 } - return shouldContinue + return shouldContinue ? .continueToReceive : .stopAndInvalidateSession } catch { - return true + return .continueToReceive } case .success(.string): - return true + return .continueToReceive case .failure(let error): continuation.finish(throwing: error) - return false + return .stopAndInvalidateSession @unknown default: - return true + return .continueToReceive } } } diff --git a/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Liveness/Service/FaceLivenessSession.swift b/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Liveness/Service/FaceLivenessSession.swift index ffc2a9abac..092532d8e3 100644 --- a/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Liveness/Service/FaceLivenessSession.swift +++ b/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Liveness/Service/FaceLivenessSession.swift @@ -17,6 +17,14 @@ public final class FaceLivenessSession: LivenessService { let baseURL: URL var serverEventListeners: [LivenessEventKind.Server: (FaceLivenessSession.SessionConfiguration) -> Void] = [:] var onComplete: (ServerDisconnection) -> Void = { _ in } + var serverDate: Date? + var savedURLForReconnect: URL? + var connectingState: ConnectingState = .normal + + enum ConnectingState { + case normal + case reconnect + } private let livenessServiceDispatchQueue = DispatchQueue( label: "com.amazon.aws.amplify.liveness.service", @@ -35,12 +43,16 @@ public final class FaceLivenessSession: LivenessService { self.websocket = websocket websocket.onMessageReceived { [weak self] result in - self?.receive(result: result) ?? false + self?.receive(result: result) ?? .stopAndInvalidateSession } websocket.onSocketClosed { [weak self] closeCode in self?.onComplete(.unexpectedClosure(closeCode)) } + + websocket.onServerDateReceived { [weak self] serverDate in + self?.serverDate = serverDate + } } public var onServiceException: (FaceLivenessSessionError) -> Void = { _ in } @@ -75,6 +87,7 @@ public final class FaceLivenessSession: LivenessService { guard let url = components?.url else { throw FaceLivenessSessionError.invalidURL } + savedURLForReconnect = url let signedConnectionURL = signer.sign(url: url) websocket.open(url: signedConnectionURL) } @@ -93,17 +106,22 @@ public final class FaceLivenessSession: LivenessService { ] ) - let eventDate = eventDate() + let dateForSigning: Date + if let serverDate = serverDate { + dateForSigning = serverDate + } else { + dateForSigning = eventDate() + } let signedPayload = self.signer.signWithPreviousSignature( payload: encodedPayload, - dateHeader: (key: ":date", value: eventDate) + dateHeader: (key: ":date", value: dateForSigning) ) let encodedEvent = self.eventStreamEncoder.encode( payload: encodedPayload, headers: [ - ":date": .timestamp(eventDate), + ":date": .timestamp(dateForSigning), ":chunk-signature": .data(signedPayload) ] ) @@ -115,7 +133,7 @@ public final class FaceLivenessSession: LivenessService { } } - private func fallbackDecoding(_ message: EventStream.Message) -> Bool { + private func fallbackDecoding(_ message: EventStream.Message) -> WebSocketSession.WebSocketMessageResult { // We only care about two events above. // Just in case the header value changes (it shouldn't) // We'll try to decode each of these events @@ -124,12 +142,12 @@ public final class FaceLivenessSession: LivenessService { self.serverEventListeners[.challenge]?(sessionConfiguration) } else if (try? JSONDecoder().decode(DisconnectEvent.self, from: message.payload)) != nil { onComplete(.disconnectionEvent) - return false + return .stopAndInvalidateSession } - return true + return .continueToReceive } - private func receive(result: Result) -> Bool { + private func receive(result: Result) -> WebSocketSession.WebSocketMessageResult { switch result { case .success(.data(let data)): do { @@ -145,28 +163,41 @@ public final class FaceLivenessSession: LivenessService { ) let sessionConfiguration = sessionConfiguration(from: payload) serverEventListeners[.challenge]?(sessionConfiguration) - return true + return .continueToReceive case .disconnect: // :event-type DisconnectionEvent onComplete(.disconnectionEvent) - return false + return .stopAndInvalidateSession default: - return true + return .continueToReceive } } else if let exceptionType = message.headers.first(where: { $0.name == ":exception-type" }) { let exceptionEvent = LivenessEventKind.Exception(rawValue: exceptionType.value) - onServiceException(.init(event: exceptionEvent)) - return false + Amplify.log.verbose("\(#function): Received exception: \(exceptionEvent)") + guard exceptionEvent == .invalidSignature, + connectingState == .normal, + let savedURLForReconnect = savedURLForReconnect, + let serverDate = serverDate else { + onServiceException(.init(event: exceptionEvent)) + return .stopAndInvalidateSession + } + + connectingState = .reconnect + let signedConnectionURL = signer.sign( + url: savedURLForReconnect, + date: { serverDate } + ) + return .invalidateSessionAndRetry(url: signedConnectionURL) } else { return fallbackDecoding(message) } } catch { - return false + return .stopAndInvalidateSession } case .success: - return true + return .continueToReceive case .failure: - return false + return .stopAndInvalidateSession } } } diff --git a/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Liveness/Service/WebSocketSession.swift b/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Liveness/Service/WebSocketSession.swift index c0c96f3c17..f900f3dfc3 100644 --- a/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Liveness/Service/WebSocketSession.swift +++ b/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Liveness/Service/WebSocketSession.swift @@ -6,13 +6,15 @@ // import Foundation +import Amplify final class WebSocketSession { private let urlSessionWebSocketDelegate: Delegate private let session: URLSession private var task: URLSessionWebSocketTask? - private var receiveMessage: ((Result) -> Bool)? + private var receiveMessage: ((Result) -> WebSocketMessageResult)? private var onSocketClosed: ((URLSessionWebSocketTask.CloseCode) -> Void)? + private var onServerDateReceived: ((Date?) -> Void)? init() { self.urlSessionWebSocketDelegate = Delegate() @@ -23,7 +25,7 @@ final class WebSocketSession { ) } - func onMessageReceived(_ receive: @escaping (Result) -> Bool) { + func onMessageReceived(_ receive: @escaping (Result) -> WebSocketMessageResult) { self.receiveMessage = receive } @@ -34,25 +36,32 @@ final class WebSocketSession { func onSocketOpened(_ onOpen: @escaping () -> Void) { urlSessionWebSocketDelegate.onOpen = onOpen } + + func onServerDateReceived(_ onServerDateReceived: @escaping (Date?) -> Void) { + urlSessionWebSocketDelegate.onServerDateReceived = onServerDateReceived + } - func receive(shouldContinue: Bool) { - guard shouldContinue else { + func receive(result: WebSocketMessageResult) { + switch result { + case .continueToReceive: + task?.receive(completionHandler: { [weak self] result in + if let webSocketResult = self?.receiveMessage?(result) { + self?.receive(result: webSocketResult) + } + }) + case .stopAndInvalidateSession: + session.finishTasksAndInvalidate() + case .invalidateSessionAndRetry(let url): session.finishTasksAndInvalidate() - return + open(url: url) } - - task?.receive(completionHandler: { [weak self] result in - if let shouldContinue = self?.receiveMessage?(result) { - self?.receive(shouldContinue: shouldContinue) - } - }) } func open(url: URL) { var request = URLRequest(url: url) request.setValue("no-store", forHTTPHeaderField: "Cache-Control") task = session.webSocketTask(with: request) - receive(shouldContinue: true) + receive(result: .continueToReceive) task?.resume() } @@ -77,10 +86,12 @@ final class WebSocketSession { ) } - final class Delegate: NSObject, URLSessionWebSocketDelegate { + final class Delegate: NSObject, URLSessionWebSocketDelegate, URLSessionTaskDelegate { var onClose: (URLSessionWebSocketTask.CloseCode) -> Void = { _ in } var onOpen: () -> Void = {} + var onServerDateReceived: (Date?) -> Void = { _ in } + // MARK: - URLSessionWebSocketDelegate methods func urlSession( _ session: URLSession, webSocketTask: URLSessionWebSocketTask, @@ -97,5 +108,34 @@ final class WebSocketSession { ) { onClose(closeCode) } + + // MARK: - URLSessionTaskDelegate methods + func urlSession(_ session: URLSession, + task: URLSessionTask, + didFinishCollecting metrics: URLSessionTaskMetrics + ) { + guard let httpResponse = metrics.transactionMetrics.first?.response as? HTTPURLResponse, + let dateString = httpResponse.value(forHTTPHeaderField: "Date") else { + Amplify.log.verbose("\(#function): Couldn't find Date header in URLSession metrics") + onServerDateReceived(nil) + return + } + + let dateFormatter = DateFormatter() + dateFormatter.dateFormat = "EEE, d MMM yyyy HH:mm:ss z" + guard let serverDate = dateFormatter.date(from: dateString) else { + Amplify.log.verbose("\(#function): Error parsing Date header in expected format") + onServerDateReceived(nil) + return + } + + onServerDateReceived(serverDate) + } + } + + enum WebSocketMessageResult { + case continueToReceive + case stopAndInvalidateSession + case invalidateSessionAndRetry(url: URL) } }