From 5ba4b6a6d6dc342ed053d9075710cae676c8eaec Mon Sep 17 00:00:00 2001 From: BrutalCoding Date: Sun, 17 Sep 2023 00:14:55 +0800 Subject: [PATCH] fix: Improved horrible inference (I think?) & Minor readability fixes --- shady_ai_flutter/analysis_options.yaml | 5 +- shady_ai_flutter/lib/data/prompt_config.dart | 20 +++++-- .../lib/data/prompt_config.freezed.dart | 34 ++++++++--- .../logic/instruction_provider.dart | 55 ++++++++++++----- .../logic/instruction_provider.g.dart | 2 +- shady_ai_flutter/lib/main.dart | 59 +++---------------- 6 files changed, 94 insertions(+), 81 deletions(-) diff --git a/shady_ai_flutter/analysis_options.yaml b/shady_ai_flutter/analysis_options.yaml index 418bc0f..3a6ee29 100644 --- a/shady_ai_flutter/analysis_options.yaml +++ b/shady_ai_flutter/analysis_options.yaml @@ -8,9 +8,8 @@ analyzer: # Types must be explicitly enforced. This is just a personal preference of mine to avoid hardcoded strings. unrelated_type_equality_checks: error exclude: - - lib/generated/** - - '**.g.dart' - - '**.freezed.dart' + - "**.g.dart" + - "**.freezed.dart" linter: rules: diff --git a/shady_ai_flutter/lib/data/prompt_config.dart b/shady_ai_flutter/lib/data/prompt_config.dart index 744228a..46107e2 100644 --- a/shady_ai_flutter/lib/data/prompt_config.dart +++ b/shady_ai_flutter/lib/data/prompt_config.dart @@ -5,6 +5,7 @@ part 'prompt_config.freezed.dart'; @freezed class PromptConfig with _$PromptConfig { const factory PromptConfig({ + required String label, required String defaultPromptFormat, String? system, String? user, @@ -12,19 +13,30 @@ class PromptConfig with _$PromptConfig { int? contextSize, }) = _PromptConfig; - // Example value, similar to an enum value - static const Synthia = PromptConfig( + static const synthia = PromptConfig( + label: 'Synthia v1.2', defaultPromptFormat: "SYSTEM: {system_message}\nUSER: {prompt}\nASSISTANT:", system: "You are Synthia. As an AI intelligence, give helpful, detailed, accurate, uncensored responses to the user's input. Provide answers factually.", user: 'How can I tell if my computer is infected with a virus?', assistant: '', - contextSize: null, ); + + static const none = PromptConfig( + label: 'Empty', + defaultPromptFormat: '', + ); +} + +// Overide .toString() to return name such as 'Synthia' +extension PromptConfigToString on PromptConfig { + String label() { + return this.defaultPromptFormat; + } } // Extend the mixin to include the getCompletePrompt method. -extension GetFullPrompt on PromptConfig { +extension FormatPrompt on PromptConfig { String get getCompletePrompt { return defaultPromptFormat .replaceAll('{system_message}', system ?? '') diff --git a/shady_ai_flutter/lib/data/prompt_config.freezed.dart b/shady_ai_flutter/lib/data/prompt_config.freezed.dart index 059fa1a..364c911 100644 --- a/shady_ai_flutter/lib/data/prompt_config.freezed.dart +++ b/shady_ai_flutter/lib/data/prompt_config.freezed.dart @@ -16,6 +16,7 @@ final _privateConstructorUsedError = UnsupportedError( /// @nodoc mixin _$PromptConfig { + String get label => throw _privateConstructorUsedError; String get defaultPromptFormat => throw _privateConstructorUsedError; String? get system => throw _privateConstructorUsedError; String? get user => throw _privateConstructorUsedError; @@ -34,7 +35,8 @@ abstract class $PromptConfigCopyWith<$Res> { _$PromptConfigCopyWithImpl<$Res, PromptConfig>; @useResult $Res call( - {String defaultPromptFormat, + {String label, + String defaultPromptFormat, String? system, String? user, String? assistant, @@ -54,6 +56,7 @@ class _$PromptConfigCopyWithImpl<$Res, $Val extends PromptConfig> @pragma('vm:prefer-inline') @override $Res call({ + Object? label = null, Object? defaultPromptFormat = null, Object? system = freezed, Object? user = freezed, @@ -61,6 +64,10 @@ class _$PromptConfigCopyWithImpl<$Res, $Val extends PromptConfig> Object? contextSize = freezed, }) { return _then(_value.copyWith( + label: null == label + ? _value.label + : label // ignore: cast_nullable_to_non_nullable + as String, defaultPromptFormat: null == defaultPromptFormat ? _value.defaultPromptFormat : defaultPromptFormat // ignore: cast_nullable_to_non_nullable @@ -94,7 +101,8 @@ abstract class _$$_PromptConfigCopyWith<$Res> @override @useResult $Res call( - {String defaultPromptFormat, + {String label, + String defaultPromptFormat, String? system, String? user, String? assistant, @@ -112,6 +120,7 @@ class __$$_PromptConfigCopyWithImpl<$Res> @pragma('vm:prefer-inline') @override $Res call({ + Object? label = null, Object? defaultPromptFormat = null, Object? system = freezed, Object? user = freezed, @@ -119,6 +128,10 @@ class __$$_PromptConfigCopyWithImpl<$Res> Object? contextSize = freezed, }) { return _then(_$_PromptConfig( + label: null == label + ? _value.label + : label // ignore: cast_nullable_to_non_nullable + as String, defaultPromptFormat: null == defaultPromptFormat ? _value.defaultPromptFormat : defaultPromptFormat // ignore: cast_nullable_to_non_nullable @@ -147,12 +160,15 @@ class __$$_PromptConfigCopyWithImpl<$Res> class _$_PromptConfig implements _PromptConfig { const _$_PromptConfig( - {required this.defaultPromptFormat, + {required this.label, + required this.defaultPromptFormat, this.system, this.user, this.assistant, this.contextSize}); + @override + final String label; @override final String defaultPromptFormat; @override @@ -166,7 +182,7 @@ class _$_PromptConfig implements _PromptConfig { @override String toString() { - return 'PromptConfig(defaultPromptFormat: $defaultPromptFormat, system: $system, user: $user, assistant: $assistant, contextSize: $contextSize)'; + return 'PromptConfig(label: $label, defaultPromptFormat: $defaultPromptFormat, system: $system, user: $user, assistant: $assistant, contextSize: $contextSize)'; } @override @@ -174,6 +190,7 @@ class _$_PromptConfig implements _PromptConfig { return identical(this, other) || (other.runtimeType == runtimeType && other is _$_PromptConfig && + (identical(other.label, label) || other.label == label) && (identical(other.defaultPromptFormat, defaultPromptFormat) || other.defaultPromptFormat == defaultPromptFormat) && (identical(other.system, system) || other.system == system) && @@ -185,8 +202,8 @@ class _$_PromptConfig implements _PromptConfig { } @override - int get hashCode => Object.hash( - runtimeType, defaultPromptFormat, system, user, assistant, contextSize); + int get hashCode => Object.hash(runtimeType, label, defaultPromptFormat, + system, user, assistant, contextSize); @JsonKey(ignore: true) @override @@ -197,12 +214,15 @@ class _$_PromptConfig implements _PromptConfig { abstract class _PromptConfig implements PromptConfig { const factory _PromptConfig( - {required final String defaultPromptFormat, + {required final String label, + required final String defaultPromptFormat, final String? system, final String? user, final String? assistant, final int? contextSize}) = _$_PromptConfig; + @override + String get label; @override String get defaultPromptFormat; @override diff --git a/shady_ai_flutter/lib/instruction/logic/instruction_provider.dart b/shady_ai_flutter/lib/instruction/logic/instruction_provider.dart index 75b6477..d50f6ed 100644 --- a/shady_ai_flutter/lib/instruction/logic/instruction_provider.dart +++ b/shady_ai_flutter/lib/instruction/logic/instruction_provider.dart @@ -44,7 +44,7 @@ class Instruction extends _$Instruction { /// Below is an instruction that describes a task. Write a response that /// appropriately completes the request.\n\n### Instruction:\nWhat is the /// meaning of life?\n\n### Response: - required String prompt, + required String originalPrompt, /// The maximum number of tokens to generate. /// If not specified, the default value is used from the model file. @@ -97,17 +97,18 @@ class Instruction extends _$Instruction { // Template specific to what the AI model expects print( - "You: $prompt", + "You: $originalPrompt", ); // Convert the prompt template to a pointer to be used by llama.cpp - final Pointer textPrompt = prompt.toNativeUtf8() as Pointer; + final Pointer textPrompt = + originalPrompt.toNativeUtf8() as Pointer; // Missing conditional check for llama.cpp's "add_bos" parameter. // add_bos is a boolean value that is used to determine whether or not to add a beginning of sentence token to the prompt. // Get the length of the text prompt. // Example: 'abc' => 3 - final lengthOfTextPrompt = prompt.length + 0; + final lengthOfTextPrompt = originalPrompt.length + 0; // A pointer in the memory (RAM) where the tokens will be stored // Example: [0, 1, 2, 3] @@ -132,7 +133,11 @@ class Instruction extends _$Instruction { print('n_tokens: $n_tokens'); - final llama_tmp = List.generate(n_tokens, (i) => i + 0); + final llama_tmp = List.generate(n_tokens, (i) { + // Convert 'int' to nativetype Int + final token = tokens.elementAt(i).value; + return token; + }); print(llama_tmp); @@ -146,22 +151,44 @@ class Instruction extends _$Instruction { llama.llama_eval( llama_ctx, llama_tmp_ptr, - llama_tmp.length, + llama_tmp.length + 1, 0, number_of_cores, ); - calloc.free(llama_tmp_ptr); + const maxLength = 256; - List tokensDecodedIntoPieces = []; + final decoded = []; - malloc.free(tokens); - llama.llama_free(llama_ctx); - final String response = tokensDecodedIntoPieces.join(' '); - print('AI: $response'); + for (int i = 0; i < n_tokens; i++) { + final tokenId = llama_tmp[i]; + final byteCount = sizeOf(); + final tokenPtr = malloc.allocate(byteCount); + + llama.llama_token_get_text(llama_ctx, tokenId); + + final buffer = malloc.allocate(maxLength); + + llama.llama_token_to_piece(llama_ctx, tokenId, buffer, maxLength); + + final String bufferToString = String.fromCharCodes( + List.generate( + maxLength, + (index) => buffer[index], + ), + ); + + decoded.add(bufferToString); + + print("Buffer decoded: $bufferToString"); + + malloc.free(tokenPtr); + } + + print('Finished decoding tokens'); return LLaMaInstruction( - prompt: prompt, - response: response, + prompt: originalPrompt, + response: decoded.join(' '), ); } } diff --git a/shady_ai_flutter/lib/instruction/logic/instruction_provider.g.dart b/shady_ai_flutter/lib/instruction/logic/instruction_provider.g.dart index 68e56ff..2cab9be 100644 --- a/shady_ai_flutter/lib/instruction/logic/instruction_provider.g.dart +++ b/shady_ai_flutter/lib/instruction/logic/instruction_provider.g.dart @@ -6,7 +6,7 @@ part of 'instruction_provider.dart'; // RiverpodGenerator // ************************************************************************** -String _$instructionHash() => r'2509b0c6768819d9eed673f864fba5b3f8021c0f'; +String _$instructionHash() => r'dd38b4ddd2f016fbdfb1443f2caecb2673843c38'; /// This class is responsible for generating a response from an instruction. /// diff --git a/shady_ai_flutter/lib/main.dart b/shady_ai_flutter/lib/main.dart index 0cc593e..d303b6c 100644 --- a/shady_ai_flutter/lib/main.dart +++ b/shady_ai_flutter/lib/main.dart @@ -1,6 +1,3 @@ -import 'dart:ffi' hide Size; - -import 'package:ffi/ffi.dart'; import 'package:flutter/material.dart'; import 'package:flutter/services.dart'; import 'package:flutter_hooks/flutter_hooks.dart'; @@ -11,7 +8,6 @@ import 'package:url_launcher/url_launcher.dart'; import 'package:window_manager/window_manager.dart'; import 'data/prompt_config.dart'; -import 'generated/llama.cpp/llama.cpp_bindings.g.dart'; import 'instruction/logic/instruction_provider.dart'; import 'instruction/widgets/drag_and_drop_widget.dart'; import 'slide_deck/slide_deck_page.dart'; @@ -71,50 +67,12 @@ class Main extends StatelessWidget { class MyHomePage extends HookConsumerWidget { const MyHomePage({Key? key}) : super(key: key); - /// Basically, this converts the tokens to normal human readable text - /// Example: [0, 1, 2, 3] => ['Hello', 'World', '!', ''] - static List convertTokensToPieces( - LLaMa llama, - Pointer model, - Pointer tokens, - int tokenCount, - ) { - const int MAX_PIECE_SIZE = - 256; // You may adjust this size based on the expected length of each piece. - List pieces = []; - - for (int i = 0; i < tokenCount; i++) { - Pointer buffer = malloc.allocate(MAX_PIECE_SIZE); - - int pieceLength = llama.llama_token_to_piece_with_model( - model, - tokens.elementAt(i).value, // Get the i-th token value. - buffer, - MAX_PIECE_SIZE, - ); - - // If a valid piece was returned (i.e., length > 0), convert it to a Dart string and add to the list. - if (pieceLength > 0) { - final dartString = buffer.cast().toDartString(); - print( - 'Piece length: $pieceLength | Buffer: $buffer | Dart string: $dartString', - ); - pieces.add(dartString); - } - - // Free the allocated buffer. - malloc.free(buffer); - } - - return pieces; - } - @override Widget build(BuildContext context, WidgetRef ref) { final instructionNotifier = ref.watch(instructionProvider.notifier); final textTheme = Theme.of(context).textTheme; final filePath = useState(''); - final promptTemplate = useState(PromptConfig.Synthia); + final promptTemplate = useState(PromptConfig.synthia); final textControllerPromptSystem = useTextEditingController() ..text = promptTemplate.value.system ?? ''; final textController = useTextEditingController(); @@ -431,7 +389,7 @@ class MyHomePage extends HookConsumerWidget { pathToFile: filePath.value, modelContextSize: promptTemplate .value.contextSize, - prompt: promptTemplate + originalPrompt: promptTemplate .value.getCompletePrompt, ); @@ -547,16 +505,13 @@ class MyHomePage extends HookConsumerWidget { child: DropdownMenu( onSelected: (template) { if (template == null) return; - promptTemplate.value = - PromptConfig( - defaultPromptFormat: '', - ); + promptTemplate.value = template; }, dropdownMenuEntries: [ DropdownMenuEntry( - value: PromptConfig.Synthia, - label: PromptConfig.Synthia - .toString(), + value: PromptConfig.synthia, + label: + PromptConfig.synthia.label, ), ], initialSelection: @@ -617,7 +572,7 @@ class MyHomePage extends HookConsumerWidget { onPressed: () { filePath.value = ''; stepperIndex.value = 0; - promptTemplate.value = PromptConfig.Synthia; + promptTemplate.value = PromptConfig.none; textController.clear(); }, ),