Skip to content

Commit

Permalink
chore(liveness): Custom websocket handler (#4371)
Browse files Browse the repository at this point in the history
* WIP

* chore: websocket handler wip

* chore(liveness): add custom websocket fetch handler

* Create soft-singers-study.md

* chore: address comments

* chore: add tests and test coverage for custom websocket handler

* fix unused variables

* fix more unused variables
  • Loading branch information
thaddmt authored Aug 31, 2023
1 parent 2d5b380 commit 41a7b42
Show file tree
Hide file tree
Showing 12 changed files with 723 additions and 87 deletions.
5 changes: 5 additions & 0 deletions .changeset/soft-singers-study.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@aws-amplify/ui-react-liveness": patch
---

chore(liveness): Custom websocket handler
3 changes: 3 additions & 0 deletions packages/react-liveness/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"@aws-amplify/ui": "5.8.0",
"@aws-amplify/ui-react": "5.3.0",
"@aws-sdk/client-rekognitionstreaming": "3.360.0",
"@smithy/eventstream-serde-browser": "^2.0.4",
"@tensorflow-models/blazeface": "0.0.7",
"@tensorflow/tfjs-backend-cpu": "3.11.0",
"@tensorflow/tfjs-backend-wasm": "3.11.0",
Expand Down Expand Up @@ -78,7 +79,9 @@
"eslint": "^8.44.0",
"jest": "^27.0.4",
"jest-canvas-mock": "^2.4.0",
"jest-websocket-mock": "^2.4.1",
"jest-when": "^3.5.1",
"mock-socket": "^9.2.1",
"react": "^17.0.2",
"react-dom": "^17.0.2",
"rimraf": "^3.0.2",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@ import {
} from '../displayText';
import { LandscapeErrorModal } from '../shared/LandscapeErrorModal';
import { CheckScreenComponents } from '../shared/FaceLivenessErrorModal';
import { selectErrorState } from '../shared';

const CHECK_CLASS_NAME = 'liveness-detector-check';

export const selectErrorState = createLivenessSelector(
(state) => state.context.errorState
);
export const selectIsRecordingStopped = createLivenessSelector(
const selectIsRecordingStopped = createLivenessSelector(
(state) => state.context.isRecordingStopped
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import {
LivenessResponseStream,
} from '@aws-sdk/client-rekognitionstreaming';
import { STATIC_VIDEO_CONSTRAINTS } from '../../StartLiveness/helpers';
import { WS_CLOSURE_CODE } from '../utils/constants';

export const MIN_FACE_MATCH_TIME = 500;

Expand Down Expand Up @@ -758,7 +759,22 @@ export const livenessMachine = createMachine<LivenessContext, LivenessEvent>(
if (freshnessColorEl) {
freshnessColorEl.style.display = 'none';
}
await context.livenessStreamProvider?.endStream();

let closureCode = WS_CLOSURE_CODE.DEFAULT_ERROR_CODE;
if (context.errorState === LivenessErrorState.TIMEOUT) {
closureCode = WS_CLOSURE_CODE.FACE_FIT_TIMEOUT;
} else if (context.errorState === LivenessErrorState.RUNTIME_ERROR) {
closureCode = WS_CLOSURE_CODE.RUNTIME_ERROR;
} else if (
context.errorState === LivenessErrorState.FACE_DISTANCE_ERROR ||
context.errorState === LivenessErrorState.MULTIPLE_FACES_ERROR
) {
closureCode = WS_CLOSURE_CODE.USER_ERROR_DURING_CONNECTION;
} else if (context.errorState === undefined) {
closureCode = WS_CLOSURE_CODE.USER_CANCEL;
}

await context.livenessStreamProvider?.endStreamWithCode(closureCode);
},
freezeStream: async (context) => {
const { videoMediaStream, videoEl } = context.videoAssociatedParams!;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
/**
* Note: This file was copied from https://github.com/aws/aws-sdk-js-v3/blob/main/packages/middleware-websocket/src/websocket-fetch-handler.ts#L176
* Because of this the file is not fully typed at this time but we should eventually work on fully typing this file.
*/
/* eslint-disable @typescript-eslint/no-unsafe-argument */
/* eslint-disable @typescript-eslint/require-await */
/* eslint-disable @typescript-eslint/no-unsafe-call */
/* eslint-disable @typescript-eslint/no-unsafe-return */
/* eslint-disable @typescript-eslint/no-unsafe-member-access */
/* eslint-disable @typescript-eslint/no-unsafe-assignment */
import { formatUrl } from '@aws-sdk/util-format-url';
import {
iterableToReadableStream,
readableStreamtoIterable,
} from '@smithy/eventstream-serde-browser';
import { FetchHttpHandler } from '@smithy/fetch-http-handler';
import { HttpRequest, HttpResponse } from '@smithy/protocol-http';
import {
Provider,
RequestHandler,
RequestHandlerMetadata,
} from '@smithy/types';
import { WS_CLOSURE_CODE } from './constants';

const DEFAULT_WS_CONNECTION_TIMEOUT_MS = 2000;

const isWebSocketRequest = (request: HttpRequest) =>
request.protocol === 'ws:' || request.protocol === 'wss:';

const isReadableStream = (payload: any): payload is ReadableStream =>
typeof ReadableStream === 'function' && payload instanceof ReadableStream;

/**
* Transfer payload data to an AsyncIterable.
* When the ReadableStream API is available in the runtime(e.g. browser), and
* the request body is ReadableStream, so we need to transfer it to AsyncIterable
* to make the stream consumable by WebSocket.
*/
const getIterator = (stream: any): AsyncIterable<any> => {
// Noop if stream is already an async iterable
if (stream[Symbol.asyncIterator]) {
return stream;
}

if (isReadableStream(stream)) {
//If stream is a ReadableStream, transfer the ReadableStream to async iterable.
return readableStreamtoIterable(stream);
}

//For other types, just wrap them with an async iterable.
return {
[Symbol.asyncIterator]: async function* () {
yield stream;
},
};
};

/**
* Convert async iterable to a ReadableStream when ReadableStream API
* is available(browsers). Otherwise, leave as it is(ReactNative).
*/
const toReadableStream = <T>(asyncIterable: AsyncIterable<T>) =>
typeof ReadableStream === 'function'
? iterableToReadableStream(asyncIterable)
: asyncIterable;

export interface WebSocketFetchHandlerOptions {
/**
* The maximum time in milliseconds that the connection phase of a request
* may take before the connection attempt is abandoned.
*/
connectionTimeout?: number;
}

/**
* Base handler for websocket requests and HTTP request. By default, the request input and output
* body will be in a ReadableStream, because of interface consistency among middleware.
* If ReadableStream is not available, like in React-Native, the response body
* will be an async iterable.
*/
export class CustomWebSocketFetchHandler {
public readonly metadata: RequestHandlerMetadata = {
handlerProtocol: 'websocket/h1.1',
};
private readonly configPromise: Promise<WebSocketFetchHandlerOptions>;
private readonly httpHandler: RequestHandler<any, any>;
private readonly sockets: Record<string, WebSocket[]> = {};
private readonly utf8decoder = new TextDecoder(); // default 'utf-8' or 'utf8'

constructor(
options?:
| WebSocketFetchHandlerOptions
| Provider<WebSocketFetchHandlerOptions>,
httpHandler: RequestHandler<any, any> = new FetchHttpHandler()
) {
this.httpHandler = httpHandler;
if (typeof options === 'function') {
this.configPromise = options().then((opts) => opts ?? {});
} else {
this.configPromise = Promise.resolve(options ?? {});
}
}

/**
* Destroys the WebSocketHandler.
* Closes all sockets from the socket pool.
*/
destroy(): void {
for (const [key, sockets] of Object.entries(this.sockets)) {
for (const socket of sockets) {
socket.close(1000, `Socket closed through destroy() call`);
}
delete this.sockets[key];
}
}

async handle(request: HttpRequest): Promise<{ response: HttpResponse }> {
if (!isWebSocketRequest(request)) {
return this.httpHandler.handle(request);
}
const url = formatUrl(request);
const socket: WebSocket = new WebSocket(url);

// Add socket to sockets pool
if (!this.sockets[url]) {
this.sockets[url] = [];
}
this.sockets[url].push(socket);

socket.binaryType = 'arraybuffer';
const { connectionTimeout = DEFAULT_WS_CONNECTION_TIMEOUT_MS } = await this
.configPromise;
await this.waitForReady(socket, connectionTimeout);
const { body } = request;
const bodyStream = getIterator(body);
const asyncIterable = this.connect(socket, bodyStream);
const outputPayload = toReadableStream(asyncIterable);
return {
response: new HttpResponse({
statusCode: 200, // indicates connection success
body: outputPayload,
}),
};
}

/**
* Removes all closing/closed sockets from the socket pool for URL.
*/
private removeNotUsableSockets(url: string): void {
this.sockets[url] = (this.sockets[url] ?? []).filter(
(socket) =>
![WebSocket.CLOSING, WebSocket.CLOSED].includes(
socket.readyState as 2 | 3
)
);
}

private waitForReady(
socket: WebSocket,
connectionTimeout: number
): Promise<void> {
return new Promise((resolve, reject) => {
const timeout = setTimeout(() => {
this.removeNotUsableSockets(socket.url);
reject({
$metadata: {
httpStatusCode: 500,
},
});
}, connectionTimeout);

socket.onopen = () => {
clearTimeout(timeout);
resolve();
};
});
}

private connect(
socket: WebSocket,
data: AsyncIterable<Uint8Array>
): AsyncIterable<Uint8Array> {
// To notify output stream any error thrown after response
// is returned while data keeps streaming.
let streamError: Error | undefined = undefined;

// To notify onclose event that error has occurred.
let socketErrorOccurred = false;

// initialize as no-op.
let reject: (err?: unknown) => void = () => {};
let resolve: ({
done,
value,
}: {
done: boolean;
value: Uint8Array;
}) => void = () => {};

socket.onmessage = (event) => {
resolve({
done: false,
value: new Uint8Array(event.data),
});
};

socket.onerror = (error) => {
socketErrorOccurred = true;
socket.close();
reject(error);
};

socket.onclose = () => {
this.removeNotUsableSockets(socket.url);
if (socketErrorOccurred) return;

if (streamError) {
reject(streamError);
} else {
resolve({
done: true,
value: undefined as any, // unchecked because done=true.
});
}
};

const outputStream: AsyncIterable<Uint8Array> = {
[Symbol.asyncIterator]: () => ({
next: () => {
return new Promise((_resolve, _reject) => {
resolve = _resolve;
reject = _reject;
});
},
}),
};

const send = async (): Promise<void> => {
try {
for await (const inputChunk of data) {
const decodedString = this.utf8decoder.decode(inputChunk);
if (decodedString.includes('closeCode')) {
const match = decodedString.match(/"closeCode":([0-9]*)/);
if (match) {
const closeCode = match[1];
socket.close(parseInt(closeCode));
}
continue;
}

socket.send(inputChunk);
}
} catch (err) {
// We don't throw the error here because the send()'s returned
// would already be settled by the time sending chunk throws error.
// Instead, the notify the output stream to throw if there's
// exceptions
streamError = err as Error | undefined;
} finally {
// WS status code: https://tools.ietf.org/html/rfc6455#section-7.4
socket.close(WS_CLOSURE_CODE.SUCCESS_CODE);
}
};

send();

return outputStream;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ export const mockOvalDetails: LivenessOvalDetails = {
};
export const mockLivenessStreamProvider: any = {
sendClientInfo: jest.fn(),
endStream: jest.fn(),
endStreamWithCode: jest.fn(),
stopVideo: jest.fn(),
dispatchStopVideoEvent: jest.fn(),
getResponseStream: jest.fn().mockResolvedValue([mockedStream]), // TODO: a following PR after PR634 will be made to have the stream emit the proper mock data.
Expand Down
Loading

0 comments on commit 41a7b42

Please sign in to comment.