Skip to content

Commit

Permalink
feat: Allow multiple model processors to increase the FPS (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
mirland authored Sep 30, 2024
1 parent 01bc15e commit d2c2edf
Show file tree
Hide file tree
Showing 27 changed files with 328 additions and 612 deletions.
51 changes: 10 additions & 41 deletions client/lib/core/common/config.dart
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ import 'package:camera/camera.dart';
import 'package:flutter/foundation.dart';
import 'package:package_info_plus/package_info_plus.dart';
import 'package:simon_ai/core/common/environments.dart';
import 'package:simon_ai/core/common/extension/string_extensions.dart';
import 'package:simon_ai/core/common/helper/enum_helpers.dart';
import 'package:simon_ai/core/common/helper/env_helper.dart';

interface class Config {
static const String environmentFolder = 'environments';
Expand All @@ -18,29 +16,22 @@ interface class Config {
static const debugMode = kDebugMode;
static const crashlyticsEnabled = !kIsWeb && !debugMode;

static late String apiBaseUrl;
static late String supabaseApiKey;

static const String userCollection = 'users';

static const cameraResolutionPreset = ResolutionPreset.medium;
static const numberOfProcessors = 3;

static final _environment = enumFromString(
Environments.values,
const String.fromEnvironment('ENV'),
) ??
Environments.dev;
static const logMlHandlers = true;
static const logMlHandlersVerbose = false;

static Future<void> initialize() async {
await _EnvConfig._setupEnv(_environment);
_initializeEnvVariables();
}
static late Environments environment;

static void _initializeEnvVariables() {
apiBaseUrl =
_EnvConfig.getEnvVariable(_EnvConfig.ENV_KEY_API_BASE_URL) ?? '';
supabaseApiKey =
_EnvConfig.getEnvVariable(_EnvConfig.ENV_KEY_SUPABASE_API_KEY) ?? '';
static Future<void> initialize() async {
environment = enumFromString(
Environments.values,
const String.fromEnvironment('ENV'),
) ??
await getEnvFromBundleId();
}

static Future<Environments> getEnvFromBundleId() async {
Expand All @@ -50,25 +41,3 @@ interface class Config {
: Environments.dev;
}
}

abstract class _EnvConfig {
static const ENV_KEY_API_BASE_URL = 'API_BASE_URL';
static const ENV_KEY_SUPABASE_API_KEY = 'SUPABASE_API_KEY';

static const systemEnv = {
ENV_KEY_API_BASE_URL: String.fromEnvironment(ENV_KEY_API_BASE_URL),
ENV_KEY_SUPABASE_API_KEY: String.fromEnvironment(ENV_KEY_SUPABASE_API_KEY),
};

static final Map<String, String> _envFileEnv = {};

static String? getEnvVariable(String key) =>
_EnvConfig.systemEnv[key].ifNullOrBlank(() => _envFileEnv[key]);

static Future<void> _setupEnv(Environments env) async {
_envFileEnv
..addAll(await loadEnvs('${Config.environmentFolder}/default.env'))
..addAll(await loadEnvs('${env.path}.env'))
..addAll(await loadEnvs('${env.path}.private.env'));
}
}
23 changes: 15 additions & 8 deletions client/lib/core/common/extension/interpreter_extensions.dart
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,29 @@ import 'package:tflite_flutter/tflite_flutter.dart';

extension InterpreterExtensions on List<Interpreter> {
Interpreter get handDetectorInterpreter => this[1];

Interpreter get handTrackingInterpreter => first;

Interpreter get handGestureEmbedderInterpreter => this[2];

Interpreter get handCannedGestureInterpreter => this[3];
}

extension InterpreterOptionsExtensions on InterpreterOptions {
void defaultOptions() {
void defaultOptions(int index) {
if (Platform.isAndroid) {
addDelegate(
GpuDelegateV2(
options: GpuDelegateOptionsV2(
isPrecisionLossAllowed: false,
inferencePriority1: 2,
if (index % 2 == 0) {
addDelegate(
GpuDelegateV2(
options: GpuDelegateOptionsV2(
isPrecisionLossAllowed: false,
inferencePriority1: 2,
),
),
),
);
);
} else {
addDelegate(XNNPackDelegate());
}
}
}
}
2 changes: 2 additions & 0 deletions client/lib/core/common/extension/string_extensions.dart
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// ignore_for_file: unused-code

extension StringExtensions on String {
bool get isUpperCase => contains(RegExp(r'^[A-Z]+$'));

Expand Down
12 changes: 0 additions & 12 deletions client/lib/core/common/helper/env_helper.dart

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import 'package:simon_ai/core/model/hand_gesture_with_position.dart';
import 'package:simon_ai/core/model/hand_gestures.dart';

const _logEnabled = true;
const _logVerbose = false;
const _logVerbose = true;

/// Transforms a stream of [HandGestureWithPosition] into a stream of game
/// [HandGestureWithPosition].
Expand All @@ -16,12 +16,12 @@ const _logVerbose = false;
class GameGestureStabilizationTransformer extends StreamTransformerBase<
HandGestureWithPosition, HandGestureWithPosition> {
static final _defaultTimeSpan = Platform.isAndroid
? const Duration(milliseconds: 500)
: const Duration(seconds: 1);
static const _defaultWindowSize = 5;
? const Duration(seconds: 1)
: const Duration(milliseconds: 300);
static const _defaultWindowSize = 7;
static final _defaultMinWindowSize = Platform.isAndroid ? 3 : 5;
static final _defaultMaxUnrecognizedGesturesInWindow =
Platform.isAndroid ? 5 : 3;
Platform.isAndroid ? 2 : 3;

final int _windowSize;
final int _minWindowSize;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import 'dart:async';
import 'dart:collection';

import 'package:mutex/mutex.dart';

typedef Processor<T, R> = FutureOr<R> Function(T);

class ProcessWhileAvailableTransformer<T, R>
extends StreamTransformerBase<T, R> {
var _isClosed = false;
final List<Processor> _availableProcessors;
final Set<Processor> _processors;

final _mutex = Mutex();
final Queue<T> _unprocessedQueue = Queue<T>();

ProcessWhileAvailableTransformer(Iterable<Processor> processors)
: _processors = processors.toSet(),
_availableProcessors = processors.toList();

Future<void> _processValue(
Processor processor,
T value,
EventSink<R> sink,
) async {
try {
T? currentValue = value;
do {
final event = await processor(currentValue);
sink.add(event);
// ignore: avoid-redundant-async
await _mutex.protect(() async {
currentValue = _unprocessedQueue.isEmpty
? null
: _unprocessedQueue.removeFirst();
});
} while (currentValue != null && !_isClosed);
} catch (e) {
if (!_isClosed) {
sink.addError(e);
}
}
}

void _handleData(T value, EventSink<R> sink) {
// ignore: avoid-redundant-async
_mutex.protect(() async {
if (_availableProcessors.isEmpty) {
_unprocessedQueue.addLast(value);
if (_unprocessedQueue.length > _processors.length) {
// Drop the oldest value when the queue is full
_unprocessedQueue.removeFirst();
}
} else {
final processor = _availableProcessors.removeAt(0);
unawaited(
_processValue(processor, value, sink).then(
(_) => _mutex.protect(
() async => _availableProcessors.add(processor),
),
),
);
}
});
}

@override
Stream<R> bind(Stream<T> stream) => stream.transform(
StreamTransformer<T, R>.fromHandlers(
handleData: _handleData,
handleError:
(Object error, StackTrace stackTrace, EventSink<R> sink) {
sink.addError(error, stackTrace);
},
handleDone: (EventSink<R> sink) {
sink.close();
_isClosed = true;
},
),
);
}
5 changes: 1 addition & 4 deletions client/lib/core/di/di_repository_module.dart
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ import 'package:simon_ai/core/services/firebase_auth.dart';
import 'package:simon_ai/core/services/permission_handler_service.dart';
import 'package:simon_ai/core/source/auth_local_source.dart';
import 'package:simon_ai/core/source/auth_remote_source.dart';
import 'package:simon_ai/core/source/common/auth_interceptor.dart';
import 'package:simon_ai/core/source/common/http_service.dart';
import 'package:simon_ai/core/source/user_remote_source.dart';

class RepositoryDiModule {
Expand All @@ -29,7 +27,6 @@ class RepositoryDiModule {

extension _GetItDiModuleExtensions on GetIt {
void _setupProvidersAndUtils() {
registerLazySingleton(() => HttpServiceDio([AuthInterceptor(get())]));
registerLazySingleton(() => FirebaseAuthService());
registerLazySingleton<PermissionHandlerInterface>(
() => MobilePermissionHandlerService(),
Expand All @@ -44,7 +41,7 @@ extension _GetItDiModuleExtensions on GetIt {

void _setupSources() {
registerLazySingleton(() => AuthLocalSource(get()));
registerLazySingleton(() => AuthRemoteSource(get(), get()));
registerLazySingleton(() => AuthRemoteSource(get()));

registerLazySingleton<UserRemoteSource>(() => UserRemoteSource());
}
Expand Down
11 changes: 8 additions & 3 deletions client/lib/core/di/di_utils_mobile_module.dart
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import 'package:get_it/get_it.dart';
import 'package:simon_ai/core/common/config.dart';
import 'package:simon_ai/core/hand_models/keypoints/gesture_processor.dart';
import 'package:simon_ai/core/hand_models/keypoints/gesture_processor_mobile.dart';

Expand All @@ -17,9 +18,13 @@ class PlatformUtilsDiModule {

extension _GetItUseCaseDiModuleExtensions on GetIt {
void _setupUtilsModule() {
registerLazySingleton<GestureMobileProcessor>(GestureMobileProcessor.new);
registerLazySingleton<GestureProcessor>(
() => get<GestureMobileProcessor>(),
registerLazySingleton<GestureProcessorPool>(
() => GestureProcessorPool(
List<GestureMobileProcessor>.generate(
Config.numberOfProcessors,
(index) => GestureMobileProcessor(index),
),
),
);
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import 'dart:math';

import 'package:simon_ai/core/common/config.dart';
import 'package:simon_ai/core/common/extension/interpreter_extensions.dart';
import 'package:simon_ai/core/common/logger.dart';
import 'package:simon_ai/core/interfaces/model_interface.dart';
Expand All @@ -11,9 +12,6 @@ import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';

class HandCannedGestureClassifier
implements ModelHandler<TensorBufferFloat, HandGesture> {
final bool _logInit = true;
final bool _logResultTime = false;

final ModelMetadata model =
(path: Assets.models.cannedGestureClassifier, inputSize: 128);

Expand All @@ -26,15 +24,18 @@ class HandCannedGestureClassifier

final stopwatch = Stopwatch();

final int processorIndex;

HandCannedGestureClassifier({
required this.processorIndex,
Interpreter? interpreter,
}) {
loadModel(interpreter: interpreter);
}

@override
Future<Interpreter> createModelInterpreter() {
final options = InterpreterOptions()..defaultOptions();
final options = InterpreterOptions()..defaultOptions(processorIndex);
return Interpreter.fromAsset(model.path, options: options);
}

Expand All @@ -46,7 +47,7 @@ class HandCannedGestureClassifier
handCannedGestureOutputLocations = outputHandCannedGestureTensors
.map((e) => TensorBufferFloat(e.shape))
.toList();
if (_logInit && interpreter == null) {
if (Config.logMlHandlers && interpreter == null) {
final handCannedGestureInputTensors = _interpreter.getInputTensors();
for (final tensor in outputHandCannedGestureTensors) {
Logger.d('Hand Detector Output Tensor: $tensor');
Expand All @@ -70,7 +71,7 @@ class HandCannedGestureClassifier

stopwatch.stop();
final processModelTime = stopwatch.elapsedMilliseconds;
if (_logResultTime) {
if (Config.logMlHandlersVerbose) {
Logger.d('processModelTime: $processModelTime');
}

Expand Down
Loading

0 comments on commit d2c2edf

Please sign in to comment.