Skip to content

Commit

Permalink
feat(predictions): add web socket retry for clock skew (#3816)
Browse files Browse the repository at this point in the history
* feat(predictions): add web socket retry for clock skew

* address review comments
  • Loading branch information
thisisabhash authored Aug 19, 2024
1 parent 611368c commit 5d54b5d
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,17 +131,17 @@ class AWSTranscribeStreamingAdapter: AWSTranscribeStreamingBehavior {
continuation.yield(transcribedPayload)
let isPartial = transcribedPayload.transcript?.results?.map(\.isPartial) ?? []
let shouldContinue = isPartial.allSatisfy { $0 }
return shouldContinue
return shouldContinue ? .continueToReceive : .stopAndInvalidateSession
} catch {
return true
return .continueToReceive
}
case .success(.string):
return true
return .continueToReceive
case .failure(let error):
continuation.finish(throwing: error)
return false
return .stopAndInvalidateSession
@unknown default:
return true
return .continueToReceive
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ public final class FaceLivenessSession: LivenessService {
let baseURL: URL
var serverEventListeners: [LivenessEventKind.Server: (FaceLivenessSession.SessionConfiguration) -> Void] = [:]
var onComplete: (ServerDisconnection) -> Void = { _ in }
var serverDate: Date?
var savedURLForReconnect: URL?
var connectingState: ConnectingState = .normal

enum ConnectingState {
case normal
case reconnect
}

private let livenessServiceDispatchQueue = DispatchQueue(
label: "com.amazon.aws.amplify.liveness.service",
Expand All @@ -35,12 +43,16 @@ public final class FaceLivenessSession: LivenessService {
self.websocket = websocket

websocket.onMessageReceived { [weak self] result in
self?.receive(result: result) ?? false
self?.receive(result: result) ?? .stopAndInvalidateSession
}

websocket.onSocketClosed { [weak self] closeCode in
self?.onComplete(.unexpectedClosure(closeCode))
}

websocket.onServerDateReceived { [weak self] serverDate in
self?.serverDate = serverDate
}
}

public var onServiceException: (FaceLivenessSessionError) -> Void = { _ in }
Expand Down Expand Up @@ -75,6 +87,7 @@ public final class FaceLivenessSession: LivenessService {
guard let url = components?.url
else { throw FaceLivenessSessionError.invalidURL }

savedURLForReconnect = url
let signedConnectionURL = signer.sign(url: url)
websocket.open(url: signedConnectionURL)
}
Expand All @@ -93,17 +106,22 @@ public final class FaceLivenessSession: LivenessService {
]
)

let eventDate = eventDate()
let dateForSigning: Date
if let serverDate = serverDate {
dateForSigning = serverDate
} else {
dateForSigning = eventDate()
}

let signedPayload = self.signer.signWithPreviousSignature(
payload: encodedPayload,
dateHeader: (key: ":date", value: eventDate)
dateHeader: (key: ":date", value: dateForSigning)
)

let encodedEvent = self.eventStreamEncoder.encode(
payload: encodedPayload,
headers: [
":date": .timestamp(eventDate),
":date": .timestamp(dateForSigning),
":chunk-signature": .data(signedPayload)
]
)
Expand All @@ -115,7 +133,7 @@ public final class FaceLivenessSession: LivenessService {
}
}

private func fallbackDecoding(_ message: EventStream.Message) -> Bool {
private func fallbackDecoding(_ message: EventStream.Message) -> WebSocketSession.WebSocketMessageResult {
// We only care about two events above.
// Just in case the header value changes (it shouldn't)
// We'll try to decode each of these events
Expand All @@ -124,12 +142,12 @@ public final class FaceLivenessSession: LivenessService {
self.serverEventListeners[.challenge]?(sessionConfiguration)
} else if (try? JSONDecoder().decode(DisconnectEvent.self, from: message.payload)) != nil {
onComplete(.disconnectionEvent)
return false
return .stopAndInvalidateSession
}
return true
return .continueToReceive
}

private func receive(result: Result<URLSessionWebSocketTask.Message, Error>) -> Bool {
private func receive(result: Result<URLSessionWebSocketTask.Message, Error>) -> WebSocketSession.WebSocketMessageResult {
switch result {
case .success(.data(let data)):
do {
Expand All @@ -145,28 +163,41 @@ public final class FaceLivenessSession: LivenessService {
)
let sessionConfiguration = sessionConfiguration(from: payload)
serverEventListeners[.challenge]?(sessionConfiguration)
return true
return .continueToReceive
case .disconnect:
// :event-type DisconnectionEvent
onComplete(.disconnectionEvent)
return false
return .stopAndInvalidateSession
default:
return true
return .continueToReceive
}
} else if let exceptionType = message.headers.first(where: { $0.name == ":exception-type" }) {
let exceptionEvent = LivenessEventKind.Exception(rawValue: exceptionType.value)
onServiceException(.init(event: exceptionEvent))
return false
Amplify.log.verbose("\(#function): Received exception: \(exceptionEvent)")
guard exceptionEvent == .invalidSignature,
connectingState == .normal,
let savedURLForReconnect = savedURLForReconnect,
let serverDate = serverDate else {
onServiceException(.init(event: exceptionEvent))
return .stopAndInvalidateSession
}

connectingState = .reconnect
let signedConnectionURL = signer.sign(
url: savedURLForReconnect,
date: { serverDate }
)
return .invalidateSessionAndRetry(url: signedConnectionURL)
} else {
return fallbackDecoding(message)
}
} catch {
return false
return .stopAndInvalidateSession
}
case .success:
return true
return .continueToReceive
case .failure:
return false
return .stopAndInvalidateSession
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
//

import Foundation
import Amplify

final class WebSocketSession {
private let urlSessionWebSocketDelegate: Delegate
private let session: URLSession
private var task: URLSessionWebSocketTask?
private var receiveMessage: ((Result<URLSessionWebSocketTask.Message, Error>) -> Bool)?
private var receiveMessage: ((Result<URLSessionWebSocketTask.Message, Error>) -> WebSocketMessageResult)?
private var onSocketClosed: ((URLSessionWebSocketTask.CloseCode) -> Void)?
private var onServerDateReceived: ((Date?) -> Void)?

init() {
self.urlSessionWebSocketDelegate = Delegate()
Expand All @@ -23,7 +25,7 @@ final class WebSocketSession {
)
}

func onMessageReceived(_ receive: @escaping (Result<URLSessionWebSocketTask.Message, Error>) -> Bool) {
func onMessageReceived(_ receive: @escaping (Result<URLSessionWebSocketTask.Message, Error>) -> WebSocketMessageResult) {
self.receiveMessage = receive
}

Expand All @@ -34,25 +36,32 @@ final class WebSocketSession {
func onSocketOpened(_ onOpen: @escaping () -> Void) {
urlSessionWebSocketDelegate.onOpen = onOpen
}

func onServerDateReceived(_ onServerDateReceived: @escaping (Date?) -> Void) {
urlSessionWebSocketDelegate.onServerDateReceived = onServerDateReceived
}

func receive(shouldContinue: Bool) {
guard shouldContinue else {
func receive(result: WebSocketMessageResult) {
switch result {
case .continueToReceive:
task?.receive(completionHandler: { [weak self] result in
if let webSocketResult = self?.receiveMessage?(result) {
self?.receive(result: webSocketResult)
}
})
case .stopAndInvalidateSession:
session.finishTasksAndInvalidate()
case .invalidateSessionAndRetry(let url):
session.finishTasksAndInvalidate()
return
open(url: url)
}

task?.receive(completionHandler: { [weak self] result in
if let shouldContinue = self?.receiveMessage?(result) {
self?.receive(shouldContinue: shouldContinue)
}
})
}

func open(url: URL) {
var request = URLRequest(url: url)
request.setValue("no-store", forHTTPHeaderField: "Cache-Control")
task = session.webSocketTask(with: request)
receive(shouldContinue: true)
receive(result: .continueToReceive)
task?.resume()
}

Expand All @@ -77,10 +86,12 @@ final class WebSocketSession {
)
}

final class Delegate: NSObject, URLSessionWebSocketDelegate {
final class Delegate: NSObject, URLSessionWebSocketDelegate, URLSessionTaskDelegate {
var onClose: (URLSessionWebSocketTask.CloseCode) -> Void = { _ in }
var onOpen: () -> Void = {}
var onServerDateReceived: (Date?) -> Void = { _ in }

// MARK: - URLSessionWebSocketDelegate methods
func urlSession(
_ session: URLSession,
webSocketTask: URLSessionWebSocketTask,
Expand All @@ -97,5 +108,34 @@ final class WebSocketSession {
) {
onClose(closeCode)
}

// MARK: - URLSessionTaskDelegate methods
func urlSession(_ session: URLSession,
task: URLSessionTask,
didFinishCollecting metrics: URLSessionTaskMetrics
) {
guard let httpResponse = metrics.transactionMetrics.first?.response as? HTTPURLResponse,
let dateString = httpResponse.value(forHTTPHeaderField: "Date") else {
Amplify.log.verbose("\(#function): Couldn't find Date header in URLSession metrics")
onServerDateReceived(nil)
return
}

let dateFormatter = DateFormatter()
dateFormatter.dateFormat = "EEE, d MMM yyyy HH:mm:ss z"
guard let serverDate = dateFormatter.date(from: dateString) else {
Amplify.log.verbose("\(#function): Error parsing Date header in expected format")
onServerDateReceived(nil)
return
}

onServerDateReceived(serverDate)
}
}

enum WebSocketMessageResult {
case continueToReceive
case stopAndInvalidateSession
case invalidateSessionAndRetry(url: URL)
}
}

0 comments on commit 5d54b5d

Please sign in to comment.