diff --git a/Cargo.lock b/Cargo.lock index 5d02727..59d721a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -378,6 +378,7 @@ dependencies = [ "ndarray", "ndarray-rand", "serde", + "serde-wasm-bindgen", "serde_json", "tokenizers", "wasm-bindgen", @@ -595,6 +596,17 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde-wasm-bindgen" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30c9933e5689bd420dc6c87b7a1835701810cbc10cd86a26e4da45b73e6b1d78" +dependencies = [ + "js-sys", + "serde", + "wasm-bindgen", +] + [[package]] name = "serde_derive" version = "1.0.181" diff --git a/crates/tokenizers/Cargo.toml b/crates/tokenizers/Cargo.toml index 9f67c05..d541faf 100644 --- a/crates/tokenizers/Cargo.toml +++ b/crates/tokenizers/Cargo.toml @@ -11,6 +11,7 @@ ndarray = "0.15.6" ndarray-rand = "0.14.0" serde = {version = "1.0", features = ["derive"]} serde_json = "1.0" +serde-wasm-bindgen = "0.6.0" tokenizers = { version="0.14.1", default-features=false, features = ["unstable_wasm"]} wasm-bindgen = "=0.2.84" getrandom = { version = "0.2", features = ["js"] } diff --git a/crates/tokenizers/src/wasm.rs b/crates/tokenizers/src/wasm.rs index 0e1dc0a..916b223 100644 --- a/crates/tokenizers/src/wasm.rs +++ b/crates/tokenizers/src/wasm.rs @@ -1,5 +1,5 @@ use crate::RESOURCES; -use std::str::FromStr; +use std::{collections::HashMap, str::FromStr}; use tokenizers::{models::bpe::BPE, tokenizer::Tokenizer}; use wasm_bindgen::prelude::*; @@ -39,7 +39,7 @@ pub fn wasm_bpe_default() -> usize { } #[wasm_bindgen] -pub fn wasm_tokenizer_tokenize(id: usize, string: String) -> Vec { +pub fn wasm_tokenizer_encode(id: usize, string: String) -> Vec { let mut data: Vec = Vec::new(); RESOURCES.with(|cell| { let tokenizers = cell.tokenizer.borrow_mut(); @@ -53,3 +53,53 @@ pub fn wasm_tokenizer_tokenize(id: usize, string: String) -> Vec { }); data } + +#[wasm_bindgen] +pub fn wasm_tokenizer_get_vocab(id: usize, with_added_tokens: bool) -> JsValue { 
+ let mut data: HashMap = HashMap::new(); + RESOURCES.with(|cell| { + let tokenizers = cell.tokenizer.borrow_mut(); + data = tokenizers[id].get_vocab(with_added_tokens) + }); + serde_wasm_bindgen::to_value(&data).unwrap() +} + +#[wasm_bindgen] +pub fn wasm_tokenizer_get_vocab_size(id: usize, with_added_tokens: bool) -> usize { + let mut data: usize = 0; + RESOURCES.with(|cell| { + let tokenizers = cell.tokenizer.borrow_mut(); + data = tokenizers[id].get_vocab_size(with_added_tokens) + }); + data +} + +#[wasm_bindgen] +pub fn wasm_tokenizer_decode(id: usize, ids: &[u32], skip_special_tokens: bool) -> String { + let mut data: String = String::new(); + RESOURCES.with(|cell| { + let tokenizers = cell.tokenizer.borrow_mut(); + data = tokenizers[id].decode(ids, skip_special_tokens).unwrap() + }); + data +} + +#[wasm_bindgen] +pub fn wasm_tokenizer_token_to_id(id: usize, token: String) -> u32 { + let mut data: u32 = 0; + RESOURCES.with(|cell| { + let tokenizers = cell.tokenizer.borrow_mut(); + data = tokenizers[id].token_to_id(token.as_str()).unwrap() + }); + data +} + +#[wasm_bindgen] +pub fn wasm_tokenizer_id_to_token(id: usize, token_id: u32) -> String { + let mut data: String = String::new(); + RESOURCES.with(|cell| { + let tokenizers = cell.tokenizer.borrow_mut(); + data = tokenizers[id].id_to_token(token_id).unwrap() + }); + data +} \ No newline at end of file diff --git a/examples/tokenizers/basic.ts b/examples/tokenizers/basic.ts index 4245000..e7b8a42 100644 --- a/examples/tokenizers/basic.ts +++ b/examples/tokenizers/basic.ts @@ -2,10 +2,13 @@ import { init, Tokenizer } from "../../tokenizers/mod.ts"; await init(); -const tokenizer = Tokenizer.fromJson( +const tokenizer = Tokenizer.fromJSON( await (await fetch( `https://huggingface.co/satvikag/chatbot/resolve/main/tokenizer.json`, )).text(), ); -console.log(tokenizer.tokenize("Hello World!")); \ No newline at end of file +const encoded = tokenizer.encode("Hello World!"); +console.log(encoded); +const decoded = 
tokenizer.decode(encoded); +console.log(decoded); \ No newline at end of file diff --git a/tokenizers/lib/netsaur_tokenizers.generated.js b/tokenizers/lib/netsaur_tokenizers.generated.js index 66bab9a..03d84f2 100644 --- a/tokenizers/lib/netsaur_tokenizers.generated.js +++ b/tokenizers/lib/netsaur_tokenizers.generated.js @@ -1,7 +1,7 @@ // @generated file from wasmbuild -- do not edit // deno-lint-ignore-file // deno-fmt-ignore-file -// source-hash: 78d396cb1d4bd48d3b906442e5d99b03fe98f891 +// source-hash: 8beb82a36802cbe280139e22e69a33415c0cb780 let wasm; const heap = new Array(128).fill(undefined); @@ -55,6 +55,71 @@ function addHeapObject(obj) { return idx; } +function debugString(val) { + // primitive types + const type = typeof val; + if (type == "number" || type == "boolean" || val == null) { + return `${val}`; + } + if (type == "string") { + return `"${val}"`; + } + if (type == "symbol") { + const description = val.description; + if (description == null) { + return "Symbol"; + } else { + return `Symbol(${description})`; + } + } + if (type == "function") { + const name = val.name; + if (typeof name == "string" && name.length > 0) { + return `Function(${name})`; + } else { + return "Function"; + } + } + // objects + if (Array.isArray(val)) { + const length = val.length; + let debug = "["; + if (length > 0) { + debug += debugString(val[0]); + } + for (let i = 1; i < length; i++) { + debug += ", " + debugString(val[i]); + } + debug += "]"; + return debug; + } + // Test for built-in + const builtInMatches = /\[object ([^\]]+)\]/.exec(toString.call(val)); + let className; + if (builtInMatches.length > 1) { + className = builtInMatches[1]; + } else { + // Failed to match the standard '[object ClassName]' + return toString.call(val); + } + if (className == "Object") { + // we're a user defined class or Object + // JSON.stringify avoids problems with cycles, and is generally much + // easier than looping through ownProperties of `val`. 
+ try { + return "Object(" + JSON.stringify(val) + ")"; + } catch (_) { + return "Object"; + } + } + // errors + if (val instanceof Error) { + return `${val.name}: ${val.message}\n${val.stack}`; + } + // TODO we could test for more things here, like `Set`s and `Map`s. + return className; +} + let WASM_VECTOR_LEN = 0; const cachedTextEncoder = new TextEncoder("utf-8"); @@ -99,6 +164,15 @@ function passStringToWasm0(arg, malloc, realloc) { WASM_VECTOR_LEN = offset; return ptr; } + +let cachedInt32Memory0 = null; + +function getInt32Memory0() { + if (cachedInt32Memory0 === null || cachedInt32Memory0.byteLength === 0) { + cachedInt32Memory0 = new Int32Array(wasm.memory.buffer); + } + return cachedInt32Memory0; +} /** * @param {string} json * @returns {number} @@ -114,14 +188,6 @@ export function wasm_tokenizer_from_json(json) { return ret >>> 0; } -let cachedInt32Memory0 = null; - -function getInt32Memory0() { - if (cachedInt32Memory0 === null || cachedInt32Memory0.byteLength === 0) { - cachedInt32Memory0 = new Int32Array(wasm.memory.buffer); - } - return cachedInt32Memory0; -} /** * @param {number} id * @param {boolean} pretty @@ -165,7 +231,7 @@ function getArrayU32FromWasm0(ptr, len) { * @param {string} string * @returns {Uint32Array} */ -export function wasm_tokenizer_tokenize(id, string) { +export function wasm_tokenizer_encode(id, string) { try { const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); const ptr0 = passStringToWasm0( @@ -174,7 +240,7 @@ export function wasm_tokenizer_tokenize(id, string) { wasm.__wbindgen_realloc, ); const len0 = WASM_VECTOR_LEN; - wasm.wasm_tokenizer_tokenize(retptr, id, ptr0, len0); + wasm.wasm_tokenizer_encode(retptr, id, ptr0, len0); var r0 = getInt32Memory0()[retptr / 4 + 0]; var r1 = getInt32Memory0()[retptr / 4 + 1]; var v1 = getArrayU32FromWasm0(r0, r1).slice(); @@ -185,6 +251,87 @@ export function wasm_tokenizer_tokenize(id, string) { } } +/** + * @param {number} id + * @param {boolean} with_added_tokens + * @returns 
{any} + */ +export function wasm_tokenizer_get_vocab(id, with_added_tokens) { + const ret = wasm.wasm_tokenizer_get_vocab(id, with_added_tokens); + return takeObject(ret); +} + +/** + * @param {number} id + * @param {boolean} with_added_tokens + * @returns {number} + */ +export function wasm_tokenizer_get_vocab_size(id, with_added_tokens) { + const ret = wasm.wasm_tokenizer_get_vocab_size(id, with_added_tokens); + return ret >>> 0; +} + +function passArray32ToWasm0(arg, malloc) { + const ptr = malloc(arg.length * 4); + getUint32Memory0().set(arg, ptr / 4); + WASM_VECTOR_LEN = arg.length; + return ptr; +} +/** + * @param {number} id + * @param {Uint32Array} ids + * @param {boolean} skip_special_tokens + * @returns {string} + */ +export function wasm_tokenizer_decode(id, ids, skip_special_tokens) { + try { + const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); + const ptr0 = passArray32ToWasm0(ids, wasm.__wbindgen_malloc); + const len0 = WASM_VECTOR_LEN; + wasm.wasm_tokenizer_decode(retptr, id, ptr0, len0, skip_special_tokens); + var r0 = getInt32Memory0()[retptr / 4 + 0]; + var r1 = getInt32Memory0()[retptr / 4 + 1]; + return getStringFromWasm0(r0, r1); + } finally { + wasm.__wbindgen_add_to_stack_pointer(16); + wasm.__wbindgen_free(r0, r1); + } +} + +/** + * @param {number} id + * @param {string} token + * @returns {number} + */ +export function wasm_tokenizer_token_to_id(id, token) { + const ptr0 = passStringToWasm0( + token, + wasm.__wbindgen_malloc, + wasm.__wbindgen_realloc, + ); + const len0 = WASM_VECTOR_LEN; + const ret = wasm.wasm_tokenizer_token_to_id(id, ptr0, len0); + return ret >>> 0; +} + +/** + * @param {number} id + * @param {number} token_id + * @returns {string} + */ +export function wasm_tokenizer_id_to_token(id, token_id) { + try { + const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); + wasm.wasm_tokenizer_id_to_token(retptr, id, token_id); + var r0 = getInt32Memory0()[retptr / 4 + 0]; + var r1 = getInt32Memory0()[retptr / 4 + 1]; + 
return getStringFromWasm0(r0, r1); + } finally { + wasm.__wbindgen_add_to_stack_pointer(16); + wasm.__wbindgen_free(r0, r1); + } +} + function handleError(f, args) { try { return f.apply(this, args); @@ -195,6 +342,32 @@ function handleError(f, args) { const imports = { __wbindgen_placeholder__: { + __wbindgen_object_drop_ref: function (arg0) { + takeObject(arg0); + }, + __wbindgen_is_string: function (arg0) { + const ret = typeof (getObject(arg0)) === "string"; + return ret; + }, + __wbindgen_error_new: function (arg0, arg1) { + const ret = new Error(getStringFromWasm0(arg0, arg1)); + return addHeapObject(ret); + }, + __wbindgen_number_new: function (arg0) { + const ret = arg0; + return addHeapObject(ret); + }, + __wbindgen_object_clone_ref: function (arg0) { + const ret = getObject(arg0); + return addHeapObject(ret); + }, + __wbindgen_string_new: function (arg0, arg1) { + const ret = getStringFromWasm0(arg0, arg1); + return addHeapObject(ret); + }, + __wbg_set_bd72c078edfa51ad: function (arg0, arg1, arg2) { + getObject(arg0)[takeObject(arg1)] = takeObject(arg2); + }, __wbg_crypto_c48a774b022d20ac: function (arg0) { const ret = getObject(arg0).crypto; return addHeapObject(ret); @@ -216,13 +389,6 @@ const imports = { const ret = getObject(arg0).node; return addHeapObject(ret); }, - __wbindgen_is_string: function (arg0) { - const ret = typeof (getObject(arg0)) === "string"; - return ret; - }, - __wbindgen_object_drop_ref: function (arg0) { - takeObject(arg0); - }, __wbg_msCrypto_bcb970640f50a1e8: function (arg0) { const ret = getObject(arg0).msCrypto; return addHeapObject(ret); @@ -237,10 +403,6 @@ const imports = { const ret = typeof (getObject(arg0)) === "function"; return ret; }, - __wbindgen_string_new: function (arg0, arg1) { - const ret = getStringFromWasm0(arg0, arg1); - return addHeapObject(ret); - }, __wbg_getRandomValues_37fa2ca9e4e07fab: function () { return handleError(function (arg0, arg1) { getObject(arg0).getRandomValues(getObject(arg1)); @@ -255,14 
+417,18 @@ const imports = { const ret = new Function(getStringFromWasm0(arg0, arg1)); return addHeapObject(ret); }, + __wbg_new_f841cc6f2098f4b5: function () { + const ret = new Map(); + return addHeapObject(ret); + }, __wbg_call_95d1ea488d03e4e8: function () { return handleError(function (arg0, arg1) { const ret = getObject(arg0).call(getObject(arg1)); return addHeapObject(ret); }, arguments); }, - __wbindgen_object_clone_ref: function (arg0) { - const ret = getObject(arg0); + __wbg_new_f9876326328f45ed: function () { + const ret = new Object(); return addHeapObject(ret); }, __wbg_self_e7c1f827057f6584: function () { @@ -299,6 +465,10 @@ const imports = { return addHeapObject(ret); }, arguments); }, + __wbg_set_388c4c6422704173: function (arg0, arg1, arg2) { + const ret = getObject(arg0).set(getObject(arg1), getObject(arg2)); + return addHeapObject(ret); + }, __wbg_buffer_cf65c07de34b9a08: function (arg0) { const ret = getObject(arg0).buffer; return addHeapObject(ret); @@ -326,6 +496,17 @@ const imports = { const ret = getObject(arg0).subarray(arg1 >>> 0, arg2 >>> 0); return addHeapObject(ret); }, + __wbindgen_debug_string: function (arg0, arg1) { + const ret = debugString(getObject(arg1)); + const ptr0 = passStringToWasm0( + ret, + wasm.__wbindgen_malloc, + wasm.__wbindgen_realloc, + ); + const len0 = WASM_VECTOR_LEN; + getInt32Memory0()[arg0 / 4 + 1] = len0; + getInt32Memory0()[arg0 / 4 + 0] = ptr0; + }, __wbindgen_throw: function (arg0, arg1) { throw new Error(getStringFromWasm0(arg0, arg1)); }, @@ -370,7 +551,7 @@ let lastLoadPromise; * @param {InstantiateOptions=} opts * @returns {Promise<{ * instance: WebAssembly.Instance; - * exports: { wasm_tokenizer_from_json: typeof wasm_tokenizer_from_json; wasm_tokenizer_save: typeof wasm_tokenizer_save; wasm_bpe_default: typeof wasm_bpe_default; wasm_tokenizer_tokenize: typeof wasm_tokenizer_tokenize } + * exports: { wasm_tokenizer_from_json: typeof wasm_tokenizer_from_json; wasm_tokenizer_save: typeof 
wasm_tokenizer_save; wasm_bpe_default: typeof wasm_bpe_default; wasm_tokenizer_encode: typeof wasm_tokenizer_encode; wasm_tokenizer_get_vocab: typeof wasm_tokenizer_get_vocab; wasm_tokenizer_get_vocab_size: typeof wasm_tokenizer_get_vocab_size; wasm_tokenizer_decode: typeof wasm_tokenizer_decode; wasm_tokenizer_token_to_id: typeof wasm_tokenizer_token_to_id; wasm_tokenizer_id_to_token: typeof wasm_tokenizer_id_to_token } * }>} */ export function instantiateWithInstance(opts) { @@ -402,7 +583,12 @@ function getWasmInstanceExports() { wasm_tokenizer_from_json, wasm_tokenizer_save, wasm_bpe_default, - wasm_tokenizer_tokenize, + wasm_tokenizer_encode, + wasm_tokenizer_get_vocab, + wasm_tokenizer_get_vocab_size, + wasm_tokenizer_decode, + wasm_tokenizer_token_to_id, + wasm_tokenizer_id_to_token, }; } diff --git a/tokenizers/lib/netsaur_tokenizers_bg.wasm b/tokenizers/lib/netsaur_tokenizers_bg.wasm index ac62c2e..3e6f076 100644 Binary files a/tokenizers/lib/netsaur_tokenizers_bg.wasm and b/tokenizers/lib/netsaur_tokenizers_bg.wasm differ diff --git a/tokenizers/mod.ts b/tokenizers/mod.ts index 19505e6..33d9b92 100644 --- a/tokenizers/mod.ts +++ b/tokenizers/mod.ts @@ -1,8 +1,13 @@ import { instantiate, + wasm_tokenizer_decode, + wasm_tokenizer_encode, wasm_tokenizer_from_json, + wasm_tokenizer_get_vocab, + wasm_tokenizer_get_vocab_size, wasm_tokenizer_save, - wasm_tokenizer_tokenize, + wasm_tokenizer_id_to_token, + wasm_tokenizer_token_to_id, } from "./lib/netsaur_tokenizers.generated.js"; let initialized = false; @@ -29,12 +34,49 @@ export class Tokenizer { } /** - * Tokenize a sentence + * Get the vocab size + */ + getVocabSize(withAddedTokens = true) { + return wasm_tokenizer_get_vocab_size(this.#id, withAddedTokens); + } + + /** + * Get the vocab + */ + getVocab(withAddedTokens = true) { + return wasm_tokenizer_get_vocab(this.#id, withAddedTokens); + } + + /** + * Get the token from an id + */ + idToToken(id: number) { + return wasm_tokenizer_id_to_token(this.#id, 
id); + } + + /** + * Get the id from a token + */ + tokenToId(token: string) { + return wasm_tokenizer_token_to_id(this.#id, token); + } + + /** + * Encode a sentence * @param sentence sentence to tokenize * @returns */ - tokenize(sentence: string) { - return wasm_tokenizer_tokenize(this.#id, sentence); + encode(sentence: string) { + return wasm_tokenizer_encode(this.#id, sentence); + } + + /** + * Decode token ids back into a string + * @param ids token ids to decode + * @returns + */ + decode(ids: Uint32Array, skipSpecialTokens = false) { + return wasm_tokenizer_decode(this.#id, ids, skipSpecialTokens); } /**