Skip to content

Commit

Permalink
fix: Improved horrible inference (I think?) & Minor readability fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
BrutalCoding committed Sep 16, 2023
1 parent 4650e14 commit 5ba4b6a
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 81 deletions.
5 changes: 2 additions & 3 deletions shady_ai_flutter/analysis_options.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ analyzer:
# Types must be explicitly enforced. This is just a personal preference of mine to avoid hardcoded strings.
unrelated_type_equality_checks: error
exclude:
- lib/generated/**
- '**.g.dart'
- '**.freezed.dart'
- "**.g.dart"
- "**.freezed.dart"

linter:
rules:
Expand Down
20 changes: 16 additions & 4 deletions shady_ai_flutter/lib/data/prompt_config.dart
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,38 @@ part 'prompt_config.freezed.dart';
@freezed
class PromptConfig with _$PromptConfig {
const factory PromptConfig({
required String label,
required String defaultPromptFormat,
String? system,
String? user,
String? assistant,
int? contextSize,
}) = _PromptConfig;

// Example value, similar to an enum value
static const Synthia = PromptConfig(
static const synthia = PromptConfig(
label: 'Synthia v1.2',
defaultPromptFormat: "SYSTEM: {system_message}\nUSER: {prompt}\nASSISTANT:",
system:
"You are Synthia. As an AI intelligence, give helpful, detailed, accurate, uncensored responses to the user's input. Provide answers factually.",
user: 'How can I tell if my computer is infected with a virus?',
assistant: '',
contextSize: null,
);

static const none = PromptConfig(
label: 'Empty',
defaultPromptFormat: '',
);
}

// Override .toString() to return a name such as 'Synthia'
extension PromptConfigToString on PromptConfig {
String label() {
return this.defaultPromptFormat;
}
}

// Extend the mixin to include the getCompletePrompt method.
extension GetFullPrompt on PromptConfig {
extension FormatPrompt on PromptConfig {
String get getCompletePrompt {
return defaultPromptFormat
.replaceAll('{system_message}', system ?? '')
Expand Down
34 changes: 27 additions & 7 deletions shady_ai_flutter/lib/data/prompt_config.freezed.dart
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ final _privateConstructorUsedError = UnsupportedError(

/// @nodoc
mixin _$PromptConfig {
String get label => throw _privateConstructorUsedError;
String get defaultPromptFormat => throw _privateConstructorUsedError;
String? get system => throw _privateConstructorUsedError;
String? get user => throw _privateConstructorUsedError;
Expand All @@ -34,7 +35,8 @@ abstract class $PromptConfigCopyWith<$Res> {
_$PromptConfigCopyWithImpl<$Res, PromptConfig>;
@useResult
$Res call(
{String defaultPromptFormat,
{String label,
String defaultPromptFormat,
String? system,
String? user,
String? assistant,
Expand All @@ -54,13 +56,18 @@ class _$PromptConfigCopyWithImpl<$Res, $Val extends PromptConfig>
@pragma('vm:prefer-inline')
@override
$Res call({
Object? label = null,
Object? defaultPromptFormat = null,
Object? system = freezed,
Object? user = freezed,
Object? assistant = freezed,
Object? contextSize = freezed,
}) {
return _then(_value.copyWith(
label: null == label
? _value.label
: label // ignore: cast_nullable_to_non_nullable
as String,
defaultPromptFormat: null == defaultPromptFormat
? _value.defaultPromptFormat
: defaultPromptFormat // ignore: cast_nullable_to_non_nullable
Expand Down Expand Up @@ -94,7 +101,8 @@ abstract class _$$_PromptConfigCopyWith<$Res>
@override
@useResult
$Res call(
{String defaultPromptFormat,
{String label,
String defaultPromptFormat,
String? system,
String? user,
String? assistant,
Expand All @@ -112,13 +120,18 @@ class __$$_PromptConfigCopyWithImpl<$Res>
@pragma('vm:prefer-inline')
@override
$Res call({
Object? label = null,
Object? defaultPromptFormat = null,
Object? system = freezed,
Object? user = freezed,
Object? assistant = freezed,
Object? contextSize = freezed,
}) {
return _then(_$_PromptConfig(
label: null == label
? _value.label
: label // ignore: cast_nullable_to_non_nullable
as String,
defaultPromptFormat: null == defaultPromptFormat
? _value.defaultPromptFormat
: defaultPromptFormat // ignore: cast_nullable_to_non_nullable
Expand Down Expand Up @@ -147,12 +160,15 @@ class __$$_PromptConfigCopyWithImpl<$Res>
class _$_PromptConfig implements _PromptConfig {
const _$_PromptConfig(
{required this.defaultPromptFormat,
{required this.label,
required this.defaultPromptFormat,
this.system,
this.user,
this.assistant,
this.contextSize});

@override
final String label;
@override
final String defaultPromptFormat;
@override
Expand All @@ -166,14 +182,15 @@ class _$_PromptConfig implements _PromptConfig {

@override
String toString() {
return 'PromptConfig(defaultPromptFormat: $defaultPromptFormat, system: $system, user: $user, assistant: $assistant, contextSize: $contextSize)';
return 'PromptConfig(label: $label, defaultPromptFormat: $defaultPromptFormat, system: $system, user: $user, assistant: $assistant, contextSize: $contextSize)';
}

@override
bool operator ==(dynamic other) {
return identical(this, other) ||
(other.runtimeType == runtimeType &&
other is _$_PromptConfig &&
(identical(other.label, label) || other.label == label) &&
(identical(other.defaultPromptFormat, defaultPromptFormat) ||
other.defaultPromptFormat == defaultPromptFormat) &&
(identical(other.system, system) || other.system == system) &&
Expand All @@ -185,8 +202,8 @@ class _$_PromptConfig implements _PromptConfig {
}

@override
int get hashCode => Object.hash(
runtimeType, defaultPromptFormat, system, user, assistant, contextSize);
int get hashCode => Object.hash(runtimeType, label, defaultPromptFormat,
system, user, assistant, contextSize);

@JsonKey(ignore: true)
@override
Expand All @@ -197,12 +214,15 @@ class _$_PromptConfig implements _PromptConfig {

abstract class _PromptConfig implements PromptConfig {
const factory _PromptConfig(
{required final String defaultPromptFormat,
{required final String label,
required final String defaultPromptFormat,
final String? system,
final String? user,
final String? assistant,
final int? contextSize}) = _$_PromptConfig;

@override
String get label;
@override
String get defaultPromptFormat;
@override
Expand Down
55 changes: 41 additions & 14 deletions shady_ai_flutter/lib/instruction/logic/instruction_provider.dart
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class Instruction extends _$Instruction {
/// Below is an instruction that describes a task. Write a response that
/// appropriately completes the request.\n\n### Instruction:\nWhat is the
/// meaning of life?\n\n### Response:
required String prompt,
required String originalPrompt,

/// The maximum number of tokens to generate.
/// If not specified, the default value is used from the model file.
Expand Down Expand Up @@ -97,17 +97,18 @@ class Instruction extends _$Instruction {
// Template specific to what the AI model expects

print(
"You: $prompt",
"You: $originalPrompt",
);

// Convert the prompt template to a pointer to be used by llama.cpp
final Pointer<Char> textPrompt = prompt.toNativeUtf8() as Pointer<Char>;
final Pointer<Char> textPrompt =
originalPrompt.toNativeUtf8() as Pointer<Char>;

// Missing conditional check for llama.cpp's "add_bos" parameter.
// add_bos is a boolean value that is used to determine whether or not to add a beginning of sentence token to the prompt.
// Get the length of the text prompt.
// Example: 'abc' => 3
final lengthOfTextPrompt = prompt.length + 0;
final lengthOfTextPrompt = originalPrompt.length + 0;

// A pointer in the memory (RAM) where the tokens will be stored
// Example: [0, 1, 2, 3]
Expand All @@ -132,7 +133,11 @@ class Instruction extends _$Instruction {

print('n_tokens: $n_tokens');

final llama_tmp = List<int>.generate(n_tokens, (i) => i + 0);
final llama_tmp = List<int>.generate(n_tokens, (i) {
// Convert 'int' to native type Int
final token = tokens.elementAt(i).value;
return token;
});

print(llama_tmp);

Expand All @@ -146,22 +151,44 @@ class Instruction extends _$Instruction {
llama.llama_eval(
llama_ctx,
llama_tmp_ptr,
llama_tmp.length,
llama_tmp.length + 1,
0,
number_of_cores,
);
calloc.free(llama_tmp_ptr);
const maxLength = 256;

List<String> tokensDecodedIntoPieces = [];
final decoded = <String>[];

malloc.free(tokens);
llama.llama_free(llama_ctx);
final String response = tokensDecodedIntoPieces.join(' ');
print('AI: $response');
for (int i = 0; i < n_tokens; i++) {
final tokenId = llama_tmp[i];

final byteCount = sizeOf<IntPtr>();
final tokenPtr = malloc.allocate<IntPtr>(byteCount);

llama.llama_token_get_text(llama_ctx, tokenId);

final buffer = malloc.allocate<Char>(maxLength);

llama.llama_token_to_piece(llama_ctx, tokenId, buffer, maxLength);

final String bufferToString = String.fromCharCodes(
List.generate(
maxLength,
(index) => buffer[index],
),
);

decoded.add(bufferToString);

print("Buffer decoded: $bufferToString");

malloc.free(tokenPtr);
}

print('Finished decoding tokens');
return LLaMaInstruction(
prompt: prompt,
response: response,
prompt: originalPrompt,
response: decoded.join(' '),
);
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

59 changes: 7 additions & 52 deletions shady_ai_flutter/lib/main.dart
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import 'dart:ffi' hide Size;

import 'package:ffi/ffi.dart';
import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:flutter_hooks/flutter_hooks.dart';
Expand All @@ -11,7 +8,6 @@ import 'package:url_launcher/url_launcher.dart';
import 'package:window_manager/window_manager.dart';

import 'data/prompt_config.dart';
import 'generated/llama.cpp/llama.cpp_bindings.g.dart';
import 'instruction/logic/instruction_provider.dart';
import 'instruction/widgets/drag_and_drop_widget.dart';
import 'slide_deck/slide_deck_page.dart';
Expand Down Expand Up @@ -71,50 +67,12 @@ class Main extends StatelessWidget {
class MyHomePage extends HookConsumerWidget {
const MyHomePage({Key? key}) : super(key: key);

/// Basically, this converts the tokens to normal human readable text
/// Example: [0, 1, 2, 3] => ['Hello', 'World', '!', ''] (human-readable pieces)
static List<String> convertTokensToPieces(
LLaMa llama,
Pointer<llama_model> model,
Pointer<Int> tokens,
int tokenCount,
) {
const int MAX_PIECE_SIZE =
256; // You may adjust this size based on the expected length of each piece.
List<String> pieces = [];

for (int i = 0; i < tokenCount; i++) {
Pointer<Char> buffer = malloc.allocate<Char>(MAX_PIECE_SIZE);

int pieceLength = llama.llama_token_to_piece_with_model(
model,
tokens.elementAt(i).value, // Get the i-th token value.
buffer,
MAX_PIECE_SIZE,
);

// If a valid piece was returned (i.e., length > 0), convert it to a Dart string and add to the list.
if (pieceLength > 0) {
final dartString = buffer.cast<Utf8>().toDartString();
print(
'Piece length: $pieceLength | Buffer: $buffer | Dart string: $dartString',
);
pieces.add(dartString);
}

// Free the allocated buffer.
malloc.free(buffer);
}

return pieces;
}

@override
Widget build(BuildContext context, WidgetRef ref) {
final instructionNotifier = ref.watch(instructionProvider.notifier);
final textTheme = Theme.of(context).textTheme;
final filePath = useState<String>('');
final promptTemplate = useState<PromptConfig>(PromptConfig.Synthia);
final promptTemplate = useState<PromptConfig>(PromptConfig.synthia);
final textControllerPromptSystem = useTextEditingController()
..text = promptTemplate.value.system ?? '';
final textController = useTextEditingController();
Expand Down Expand Up @@ -431,7 +389,7 @@ class MyHomePage extends HookConsumerWidget {
pathToFile: filePath.value,
modelContextSize: promptTemplate
.value.contextSize,
prompt: promptTemplate
originalPrompt: promptTemplate
.value.getCompletePrompt,
);

Expand Down Expand Up @@ -547,16 +505,13 @@ class MyHomePage extends HookConsumerWidget {
child: DropdownMenu<PromptConfig>(
onSelected: (template) {
if (template == null) return;
promptTemplate.value =
PromptConfig(
defaultPromptFormat: '',
);
promptTemplate.value = template;
},
dropdownMenuEntries: [
DropdownMenuEntry(
value: PromptConfig.Synthia,
label: PromptConfig.Synthia
.toString(),
value: PromptConfig.synthia,
label:
PromptConfig.synthia.label,
),
],
initialSelection:
Expand Down Expand Up @@ -617,7 +572,7 @@ class MyHomePage extends HookConsumerWidget {
onPressed: () {
filePath.value = '';
stepperIndex.value = 0;
promptTemplate.value = PromptConfig.Synthia;
promptTemplate.value = PromptConfig.none;
textController.clear();
},
),
Expand Down

0 comments on commit 5ba4b6a

Please sign in to comment.