Skip to content

Commit

Permalink
chore(predictions): add attempt count changes and unit tests (#3657)
Browse files Browse the repository at this point in the history
* chore(predictions): add attempt count changes and unit tests

* remove test url

* Add coding keys for Challenge object
  • Loading branch information
thisisabhash committed Aug 22, 2024
1 parent 15489e7 commit 4815461
Show file tree
Hide file tree
Showing 6 changed files with 259 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ extension AWSPredictionsPlugin {
withID sessionID: String,
credentialsProvider: AWSCredentialsProvider? = nil,
region: String,
options: FaceLivenessSession.Options,
completion: @escaping (Result<Void, FaceLivenessSessionError>) -> Void
) async throws -> FaceLivenessSession {

Expand All @@ -36,8 +35,7 @@ extension AWSPredictionsPlugin {
let session = FaceLivenessSession(
websocket: WebSocketSession(),
signer: signer,
baseURL: url,
options: options
baseURL: url
)

session.onServiceException = { completion(.failure($0)) }
Expand All @@ -49,14 +47,14 @@ extension AWSPredictionsPlugin {
extension FaceLivenessSession {
@_spi(PredictionsFaceLiveness)
public struct Options {
public let viewId: String
public let attemptCount: Int
public let preCheckViewEnabled: Bool

public init(
faceLivenessDetectorViewId: String,
attemptCount: Int,
preCheckViewEnabled: Bool
) {
self.viewId = faceLivenessDetectorViewId
self.attemptCount = attemptCount
self.preCheckViewEnabled = preCheckViewEnabled
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import Foundation

@_spi(PredictionsFaceLiveness)
public struct Challenge {
public struct Challenge: Codable {
public let version: String
public let type: ChallengeType

Expand All @@ -20,6 +20,11 @@ public struct Challenge {
public func queryParameterString() -> String {
return self.type.rawValue + "_" + self.version
}

enum CodingKeys: String, CodingKey {
case version = "Version"
case type = "Type"
}
}

@_spi(PredictionsFaceLiveness)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,12 @@ public final class FaceLivenessSession: LivenessService {
init(
websocket: WebSocketSession,
signer: SigV4Signer,
baseURL: URL,
options: FaceLivenessSession.Options
baseURL: URL
) {
self.eventStreamEncoder = EventStream.Encoder()
self.eventStreamDecoder = EventStream.Decoder()
self.signer = signer
self.baseURL = baseURL
self.options = options

self.websocket = websocket

Expand Down Expand Up @@ -84,14 +82,13 @@ public final class FaceLivenessSession: LivenessService {

public func initializeLivenessStream(withSessionID sessionID: String,
userAgent: String = "",
challenges: [Challenge] = FaceLivenessSession.supportedChallenges) throws {
challenges: [Challenge] = FaceLivenessSession.supportedChallenges,
options: FaceLivenessSession.Options) throws {
var components = URLComponents(url: baseURL, resolvingAgainstBaseURL: false)

components?.queryItems = [
URLQueryItem(name: "session-id", value: sessionID),
URLQueryItem(name: "precheck-view-enabled", value: options.preCheckViewEnabled ? "1":"0"),
// TODO: Change this after confirmation
URLQueryItem(name: "attempt-id", value: options.viewId),
URLQueryItem(name: "attempt-count", value: String(options.attemptCount)),
URLQueryItem(name: "challenge-versions",
value: challenges.map({$0.queryParameterString()}).joined(separator: ",")),
URLQueryItem(name: "video-width", value: "480"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ public protocol LivenessService {

func initializeLivenessStream(withSessionID sessionID: String,
userAgent: String,
challenges: [Challenge]) throws
challenges: [Challenge],
options: FaceLivenessSession.Options) throws

func register(
listener: @escaping (FaceLivenessSession.SessionConfiguration) -> Void,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
//
// Copyright Amazon.com Inc. or its affiliates.
// All Rights Reserved.
//
// SPDX-License-Identifier: Apache-2.0
//

import XCTest
import Amplify
@testable import AWSPredictionsPlugin
@_spi(PredictionsFaceLiveness) import AWSPredictionsPlugin

class LivenessChallengeTests: XCTestCase {

func testFaceMovementChallengeQueryParamterString() {
let challenge = Challenge(version: "1.0.0", type: .faceMovementChallenge)
XCTAssertEqual(challenge.queryParameterString(), "FaceMovementChallenge_1.0.0")
}

func testFaceMovementAndLightChallengeQueryParamterString() {
let challenge = Challenge(version: "2.0.0", type: .faceMovementAndLightChallenge)
XCTAssertEqual(challenge.queryParameterString(), "FaceMovementAndLightChallenge_2.0.0")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
//
// Copyright Amazon.com Inc. or its affiliates.
// All Rights Reserved.
//
// SPDX-License-Identifier: Apache-2.0
//

import XCTest
import Amplify
@testable import AWSPredictionsPlugin
@_spi(PredictionsFaceLiveness) import AWSPredictionsPlugin

class LivenessDecodingTests: XCTestCase {

// MARK: - ChallengeEvent
/// - Given: A valid json payload depicting a FaceMovementChallenge
/// - When: The payload is decoded
/// - Then: The payload is decoded successfully
func testFacemovementChallengeEventDecodeSuccess() {
let jsonString =
"""
{"Type":"FaceMovementChallenge","Version":"1.0.0"}
"""

do {
let data = jsonString.data(using: .utf8)
guard let data = data else {
XCTFail("Input JSON is invalid")
return
}
let challengeEvent = try JSONDecoder().decode(
ChallengeEvent.self, from: data
)

XCTAssertEqual(challengeEvent.type, ChallengeType.faceMovementChallenge)
XCTAssertEqual(challengeEvent.version, "1.0.0")
} catch {
XCTFail("Decoding failed with error: \(error)")
}
}

/// - Given: A valid json payload depicting a FaceMovementAndLightChallenge
/// - When: The payload is decoded
/// - Then: The payload is decoded successfully
func testFacemovementAndLightChallengeEventDecodeSuccess() {
let jsonString =
"""
{"Type":"FaceMovementAndLightChallenge","Version":"1.0.0"}
"""

do {
let data = jsonString.data(using: .utf8)
guard let data = data else {
XCTFail("Input JSON is invalid")
return
}
let challengeEvent = try JSONDecoder().decode(
ChallengeEvent.self, from: data
)

XCTAssertEqual(challengeEvent.type, ChallengeType.faceMovementAndLightChallenge)
XCTAssertEqual(challengeEvent.version, "1.0.0")
} catch {
XCTFail("Decoding failed with error: \(error)")
}
}

/// - Given: A valid json payload depicting an unknown challenge
/// - When: The payload is decoded
/// - Then: Error is thrown
func testUnknownChallengeEventDecodeFailure() {
let jsonString =
"""
{"Type":"UnknownChallenge","Version":"1.0.0"}
"""

do {
let data = jsonString.data(using: .utf8)
guard let data = data else {
XCTFail("Input JSON is invalid")
return
}
_ = try JSONDecoder().decode(
ChallengeEvent.self, from: data
)

XCTFail("Decoding should fail for unknown challenge")
} catch {
XCTAssertNotNil(error)
}
}

// MARK: - ServerSessionInformationEvent

/// - Given: A valid json payload depicting a ServerSessionInformation
/// containing FaceMovementChallenge
/// - When: The payload is decoded
/// - Then: The payload is decoded successfully
func testFaceMovementChallengeServerSessionInformationEventDecodeSuccess() {
let jsonString =
"""
{\"SessionInformation\":{\"Challenge\":{\"FaceMovementChallenge\":{\"OvalParameters\":{\"Width\":0.1,\"Height\":0.1,\"CenterY\":0.1,\"CenterX\":0.1},\"ChallengeConfig\":{\"BlazeFaceDetectionThreshold\":0.1,\"FaceIouHeightThreshold\":0.1,\"OvalHeightWidthRatio\":0.1,\"OvalIouHeightThreshold\":0.1,\"OvalFitTimeout\":1,\"OvalIouWidthThreshold\":0.1,\"OvalIouThreshold\":0.1,\"FaceDistanceThreshold\":0.1,\"FaceDistanceThresholdMax\":0.1,\"FaceIouWidthThreshold\":0.1,\"FaceDistanceThresholdMin\":0.1}}}}}
"""

do {
let data = jsonString.data(using: .utf8)
guard let data = data else {
XCTFail("Input JSON is invalid")
return
}
let serverSessionInformationEvent = try JSONDecoder().decode(
ServerSessionInformationEvent.self, from: data
)

guard case let .faceMovementChallenge(challenge: recoveredChallenge) =
serverSessionInformationEvent.sessionInformation.challenge.type else {
XCTFail("Cannot decode event from the input JSON")
return
}

XCTAssertEqual(recoveredChallenge.ovalParameters.height, 0.1)
XCTAssertEqual(recoveredChallenge.ovalParameters.width, 0.1)
XCTAssertEqual(recoveredChallenge.ovalParameters.centerX, 0.1)
XCTAssertEqual(recoveredChallenge.ovalParameters.centerY, 0.1)

XCTAssertEqual(recoveredChallenge.challengeConfig.blazeFaceDetectionThreshold, 0.1)
XCTAssertEqual(recoveredChallenge.challengeConfig.faceDistanceThreshold, 0.1)
XCTAssertEqual(recoveredChallenge.challengeConfig.faceDistanceThresholdMax, 0.1)
XCTAssertEqual(recoveredChallenge.challengeConfig.faceDistanceThresholdMin, 0.1)
XCTAssertEqual(recoveredChallenge.challengeConfig.faceIouHeightThreshold, 0.1)
XCTAssertEqual(recoveredChallenge.challengeConfig.faceIouWidthThreshold, 0.1)
XCTAssertEqual(recoveredChallenge.challengeConfig.ovalHeightWidthRatio, 0.1)
XCTAssertEqual(recoveredChallenge.challengeConfig.ovalIouHeightThreshold, 0.1)
XCTAssertEqual(recoveredChallenge.challengeConfig.ovalIouThreshold, 0.1)
XCTAssertEqual(recoveredChallenge.challengeConfig.ovalIouWidthThreshold, 0.1)
XCTAssertEqual(recoveredChallenge.challengeConfig.ovalFitTimeout, 1)
} catch {
XCTFail("Decoding failed with error: \(error)")
}
}

/// - Given: A valid json payload depicting a ServerSessionInformation
/// containing FaceMovementAndLightChallenge
/// - When: The payload is decoded
/// - Then: The payload is decoded successfully
func testFaceMovementAndLightChallengeServerSessionInformationEventDecodeSuccess() {
let jsonString =
"""
{\"SessionInformation\":{\"Challenge\":{\"FaceMovementAndLightChallenge\":{\"OvalParameters\":{\"Height\":0.1,\"CenterX\":0.1,\"Width\":0.1,\"CenterY\":0.1},\"ColorSequences\":[{\"FreshnessColor\":{\"RGB\":[255,255,255]},\"DownscrollDuration\":0.1,\"FlatDisplayDuration\":0.1}],\"ChallengeConfig\":{\"OvalIouWidthThreshold\":0.1,\"FaceDistanceThreshold\":0.1,\"OvalFitTimeout\":1,\"FaceIouHeightThreshold\":0.1,\"FaceDistanceThresholdMax\":0.1,\"FaceDistanceThresholdMin\":0.1,\"OvalIouHeightThreshold\":0.1,\"FaceIouWidthThreshold\":0.1,\"OvalIouThreshold\":0.1,\"BlazeFaceDetectionThreshold\":0.1,\"OvalHeightWidthRatio\":0.1}}}}}
"""

do {
let data = jsonString.data(using: .utf8)
guard let data = data else {
XCTFail("Input JSON is invalid")
return
}
let serverSessionInformationEvent = try JSONDecoder().decode(
ServerSessionInformationEvent.self, from: data
)

guard case let .faceMovementAndLightChallenge(challenge: recoveredChallenge) =
serverSessionInformationEvent.sessionInformation.challenge.type else {
XCTFail("Cannot decode event from the input JSON")
return
}

XCTAssertEqual(recoveredChallenge.ovalParameters.height, 0.1)
XCTAssertEqual(recoveredChallenge.ovalParameters.width, 0.1)
XCTAssertEqual(recoveredChallenge.ovalParameters.centerX, 0.1)
XCTAssertEqual(recoveredChallenge.ovalParameters.centerY, 0.1)

XCTAssertEqual(recoveredChallenge.challengeConfig.blazeFaceDetectionThreshold, 0.1)
XCTAssertEqual(recoveredChallenge.challengeConfig.faceDistanceThreshold, 0.1)
XCTAssertEqual(recoveredChallenge.challengeConfig.faceDistanceThresholdMax, 0.1)
XCTAssertEqual(recoveredChallenge.challengeConfig.faceDistanceThresholdMin, 0.1)
XCTAssertEqual(recoveredChallenge.challengeConfig.faceIouHeightThreshold, 0.1)
XCTAssertEqual(recoveredChallenge.challengeConfig.faceIouWidthThreshold, 0.1)
XCTAssertEqual(recoveredChallenge.challengeConfig.ovalHeightWidthRatio, 0.1)
XCTAssertEqual(recoveredChallenge.challengeConfig.ovalIouHeightThreshold, 0.1)
XCTAssertEqual(recoveredChallenge.challengeConfig.ovalIouThreshold, 0.1)
XCTAssertEqual(recoveredChallenge.challengeConfig.ovalIouWidthThreshold, 0.1)
XCTAssertEqual(recoveredChallenge.challengeConfig.ovalFitTimeout, 1)

XCTAssertEqual(recoveredChallenge.colorSequences.count, 1)
XCTAssertEqual(recoveredChallenge.colorSequences.first?.downscrollDuration, 0.1)
XCTAssertEqual(recoveredChallenge.colorSequences.first?.flatDisplayDuration, 0.1)
XCTAssertEqual(recoveredChallenge.colorSequences.first?.freshnessColor.rgb, [255,255,255])
} catch {
XCTFail("Decoding failed with error: \(error)")
}
}

/// - Given: A valid json payload depicting a ServerSessionInformation
/// containing unknown challenge
/// - When: The payload is decoded
/// - Then: Error should be thrown
func testUnknownChallengeServerSessionInformationEventDecodeFailure() {
let jsonString =
"""
{\"SessionInformation\":{\"Challenge\":{\"UnknownChallenge\":{\"OvalParameters\":{\"Height\":0.1,\"CenterX\":0.1,\"Width\":0.1,\"CenterY\":0.1},\"ColorSequences\":[{\"FreshnessColor\":{\"RGB\":[255,255,255]},\"DownscrollDuration\":0.1,\"FlatDisplayDuration\":0.1}],\"ChallengeConfig\":{\"OvalIouWidthThreshold\":0.1,\"FaceDistanceThreshold\":0.1,\"OvalFitTimeout\":1,\"FaceIouHeightThreshold\":0.1,\"FaceDistanceThresholdMax\":0.1,\"FaceDistanceThresholdMin\":0.1,\"OvalIouHeightThreshold\":0.1,\"FaceIouWidthThreshold\":0.1,\"OvalIouThreshold\":0.1,\"BlazeFaceDetectionThreshold\":0.1,\"OvalHeightWidthRatio\":0.1}}}}}
"""

do {
let data = jsonString.data(using: .utf8)
guard let data = data else {
XCTFail("Input JSON is invalid")
return
}
let serverSessionInformationEvent = try JSONDecoder().decode(
ServerSessionInformationEvent.self, from: data
)

XCTFail("Decoding should fail for unknown challenge")
} catch {
XCTAssertNotNil(error)
}
}
}

0 comments on commit 4815461

Please sign in to comment.