Ensure reversed text matches original
Previously, we optimized by calling removeDiacritics + toLowerCase once on the entire input string.

However, this meant the text returned alongside an embedding was the lowercased, diacritics-stripped version.

That is unacceptable for RAG use cases, where we want to send the original text to the LLM, not the lowercased version. The normalization doesn't seem to affect the LLM itself, but it looks odd in app UI that displays what was sent.
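To illustrate the fix, here is a minimal standalone sketch (assuming the `diacritic` package provides the `removeDiacritics` used in the diff below; the input string is hypothetical): normalization now applies per word, only for the vocabulary lookup, while the original words are preserved for the returned text.

import 'package:diacritic/diacritic.dart';

void main() {
  const input = 'Crème Brûlée Recipe';

  // Old behavior: the whole string was normalized up front, so the
  // normalized form was all that remained to return to the caller.
  final oldText = removeDiacritics(input).toLowerCase();

  // New behavior: keep the original words for the returned text; derive a
  // normalized copy of each word solely for the token lookup.
  final words = input.trim().split(RegExp(r'\s+'));
  final lookups = [for (final w in words) removeDiacritics(w.toLowerCase())];

  print(oldText);         // creme brulee recipe
  print(words.join(' ')); // Crème Brûlée Recipe  <- what tokenize() now returns
  print(lookups);         // [creme, brulee, recipe] <- what the lookup sees
}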
jpohhhh committed Aug 28, 2024
1 parent 1cdf665 · commit a1d5e3d
Showing 1 changed file with 45 additions and 48 deletions.

lib/tokenizers/wordpiece_tokenizer.dart
@@ -49,53 +49,6 @@ class WordpieceTokenizer {
         (0x30A0 <= codeUnit && codeUnit <= 0x30FF);
   }
 
-  List<TextAndTokens> _createSubstrings(String text, {int? maxTokens}) {
-    final max = maxTokens ?? maxInputTokens;
-    text = text.trim();
-    if (text.isEmpty) {
-      return [
-        TextAndTokens(text: '', tokens: [101, 102])
-      ];
-    }
-    text = removeDiacritics(text);
-    text = text.toLowerCase();
-
-    List<List<int>> allOutputTokens = [];
-    List<String> allOutputStrings = [];
-    List<String> words = text.split(RegExp(r'\s+')); // Split on whitespace
-
-    List<int> outputTokens = [startToken];
-    List<String> outputString = [];
-    for (final word in words) {
-      if (word.length > maxInputCharsPerWord) {
-        continue;
-      }
-
-      List<int> wordTokens = _tokenizeWord(word);
-
-      if (outputTokens.length + wordTokens.length >= max - 1) {
-        outputTokens.add(endToken);
-        allOutputStrings.add(outputString.join(' '));
-        allOutputTokens.add(outputTokens);
-        outputString = [word];
-        outputTokens = [startToken, ...wordTokens];
-      } else {
-        outputString.add(word);
-        outputTokens.addAll(wordTokens);
-      }
-    }
-    outputTokens.add(endToken);
-    allOutputTokens.add(outputTokens);
-    allOutputStrings.add(outputString.join(' '));
-    assert(allOutputStrings.length == allOutputTokens.length);
-    return List<TextAndTokens>.generate(allOutputStrings.length, (index) {
-      return TextAndTokens(
-        text: allOutputStrings[index],
-        tokens: allOutputTokens[index],
-      );
-    });
-  }
-
   List<int> _tokenizeWord(String word) {
     List<int> wordTokens = [];
     int start = 0;
@@ -136,7 +89,51 @@
   }
 
   List<TextAndTokens> tokenize(String text, {int? maxTokens}) {
-    return _createSubstrings(text, maxTokens: maxTokens);
+    final max = maxTokens ?? maxInputTokens;
+    text = text.trim();
+    if (text.isEmpty) {
+      return [
+        TextAndTokens(text: '', tokens: [startToken, endToken])
+      ];
+    }
+
+    List<List<int>> allOutputTokens = [];
+    List<String> allOutputStrings = [];
+    List<String> words = text.split(RegExp(r'\s+')); // Split on whitespace
+
+    List<int> outputTokens = [startToken];
+    List<String> outputString = [];
+    String normalizedWord = '';
+
+    for (final word in words) {
+      if (word.length > maxInputCharsPerWord) {
+        continue;
+      }
+
+      normalizedWord = removeDiacritics(word.toLowerCase());
+      List<int> wordTokens = _tokenizeWord(normalizedWord);
+
+      if (outputTokens.length + wordTokens.length >= max - 1) {
+        outputTokens.add(endToken);
+        allOutputStrings.add(outputString.join(' '));
+        allOutputTokens.add(outputTokens);
+        outputString = [word];
+        outputTokens = [startToken, ...wordTokens];
+      } else {
+        outputString.add(word);
+        outputTokens.addAll(wordTokens);
+      }
+    }
+    outputTokens.add(endToken);
+    allOutputTokens.add(outputTokens);
+    allOutputStrings.add(outputString.join(' '));
+
+    return List<TextAndTokens>.generate(allOutputStrings.length, (index) {
+      return TextAndTokens(
+        text: allOutputStrings[index],
+        tokens: allOutputTokens[index],
+      );
+    });
   }
 
   String detokenize(List<int> tokens) {
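For reference, here is the chunking invariant the new tokenize body implements, as a standalone sketch (the token IDs follow the BERT-style 101 [CLS] and 102 [SEP] used by the old empty-input path; the per-word tokenizer is a hypothetical stub, not the real vocabulary lookup):

// Each emitted chunk is [startToken, ...word tokens, endToken] and stays
// within maxTokens; a word that would overflow starts the next chunk.
const startToken = 101; // [CLS]
const endToken = 102; // [SEP]

List<int> stubWordTokens(String word) => List.filled(word.length, 0);

List<List<int>> chunk(List<String> words, int maxTokens) {
  final chunks = <List<int>>[];
  var current = <int>[startToken];
  for (final word in words) {
    final wordTokens = stubWordTokens(word);
    if (current.length + wordTokens.length >= maxTokens - 1) {
      chunks.add([...current, endToken]);
      current = [startToken, ...wordTokens];
    } else {
      current.addAll(wordTokens);
    }
  }
  chunks.add([...current, endToken]);
  return chunks;
}

void main() {
  for (final c in chunk(['aaaa', 'bbb', 'cc', 'd'], 8)) {
    print(c); // every chunk starts with 101 and ends with 102
  }
}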
