diff --git a/Cargo.lock b/Cargo.lock
index 3abdc93..43fcdce 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -330,22 +330,24 @@ dependencies = [
 
 [[package]]
 name = "ndarray"
-version = "0.15.6"
+version = "0.16.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
+checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841"
 dependencies = [
  "matrixmultiply",
  "num-complex",
  "num-integer",
  "num-traits",
+ "portable-atomic",
+ "portable-atomic-util",
  "rawpointer",
 ]
 
 [[package]]
 name = "ndarray-rand"
-version = "0.14.0"
+version = "0.15.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "65608f937acc725f5b164dcf40f4f0bc5d67dc268ab8a649d3002606718c4588"
+checksum = "f093b3db6fd194718dcdeea6bd8c829417deae904e3fcc7732dabcd4416d25d8"
 dependencies = [
  "ndarray",
  "rand",
@@ -445,6 +447,21 @@
 version = "1.0.14"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c"
 
+[[package]]
+name = "portable-atomic"
+version = "1.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265"
+
+[[package]]
+name = "portable-atomic-util"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fcdd8420072e66d54a407b3316991fe946ce3ab1083a7f575b2463866624704d"
+dependencies = [
+ "portable-atomic",
+]
+
 [[package]]
 name = "ppv-lite86"
 version = "0.2.17"
@@ -583,9 +600,9 @@ checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741"
 
 [[package]]
 name = "safetensors"
-version = "0.4.0"
+version = "0.4.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f1e5186bd51ae3f90999d243853f5e8cb51f3467f55da42dc611ed2342483dad"
+checksum = "44560c11236a6130a46ce36c836a62936dc81ebf8c36a37947423571be0e55b6"
 dependencies = [
  "serde",
  "serde_json",
diff --git a/Cargo.toml b/Cargo.toml
index 657fd35..c9397d4 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -2,3 +2,12 @@
 package.version = "0.4.0"
 members = ["crates/*"]
 resolver = "2"
+
+[workspace.dependencies]
+cudarc = "0.9.14"
+ndarray = "0.16.1"
+ndarray-rand = "0.15.0"
+safetensors = "0.4.5"
+serde = { version = "1.0", features = ["derive"] }
+serde_json = "1.0"
+thiserror = "1.0.49"
\ No newline at end of file
diff --git a/bench/netsaur_cpu.ts b/bench/netsaur_cpu.ts
index 74f749b..49ca7f5 100644
--- a/bench/netsaur_cpu.ts
+++ b/bench/netsaur_cpu.ts
@@ -11,72 +11,34 @@ import {
 
 await setupBackend(CPU);
 
-Deno.bench(
-  { name: "xor 10000 epochs", permissions: "inherit" },
-  async () => {
-    const net = new Sequential({
-      size: [4, 2],
-      silent: true,
-      layers: [
-        DenseLayer({ size: [3] }),
-        SigmoidLayer(),
-        DenseLayer({ size: [1] }),
-        SigmoidLayer(),
-      ],
-      cost: Cost.MSE,
-    });
-
-    net.train(
-      [
-        {
-          inputs: tensor2D([
-            [0, 0],
-            [1, 0],
-            [0, 1],
-            [1, 1],
-          ]),
-          outputs: tensor2D([[0], [1], [1], [0]]),
-        },
-      ],
-      10000,
-    );
-
-    console.log((await net.predict(tensor1D([0, 0]))).data);
-    console.log((await net.predict(tensor1D([1, 0]))).data);
-    console.log((await net.predict(tensor1D([0, 1]))).data);
-    console.log((await net.predict(tensor1D([1, 1]))).data);
-  },
+const net = new Sequential({
+  size: [4, 2],
+  silent: true,
+  layers: [
+    DenseLayer({ size: [3] }),
+    SigmoidLayer(),
+    DenseLayer({ size: [1] }),
+    SigmoidLayer(),
+  ],
+  cost: Cost.MSE,
+});
+
+net.train(
+  [
+    {
+      inputs: tensor2D([
+        [0, 0],
+        [1, 0],
+        [0, 1],
+        [1, 1],
+      ]),
+      outputs: tensor2D([[0], [1], [1], [0]]),
+    },
+  ],
+  10000,
 );
-
-// const net = new NeuralNetwork({
-//   size: [4, 2],
-//   silent: true,
-//   layers: [
-//     DenseLayer({ size: [3], activation: Activation.Sigmoid }),
-//     DenseLayer({ size: [1], activation: Activation.Sigmoid }),
-//   ],
-//   cost: Cost.MSE,
-// });
-
-// const time = performance.now();
-
-// net.train(
-//   [
-//     {
-//       inputs: tensor2D([
-//         [0, 0],
-//         [1, 0],
-//         [0, 1],
-//         [1, 1],
-//       ]),
-//       outputs: tensor2D([[0], [1], [1], [0]]),
-//     },
-//   ],
-//   10000,
-// )
-
-// console.log(`training time: ${performance.now() - time}ms`);
-// console.log((await net.predict(tensor2D([[0, 0]]))).data);
-// console.log((await net.predict(tensor2D([[1, 0]]))).data);
-// console.log((await net.predict(tensor2D([[0, 1]]))).data);
-// console.log((await net.predict(tensor2D([[1, 1]]))).data);
+
+console.log((await net.predict(tensor1D([0, 0]))).data);
+console.log((await net.predict(tensor1D([1, 0]))).data);
+console.log((await net.predict(tensor1D([0, 1]))).data);
+console.log((await net.predict(tensor1D([1, 1]))).data);
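Note: with the Deno.bench harness removed, netsaur_cpu.ts is now a plain script and no longer reports a timing on its own. If a wall-clock figure is still wanted, the performance.now() bracket from the old commented-out benchmark still applies around the train() call; a minimal sketch, where `data` is an illustrative name for the one-element dataset array built above:

    const start = performance.now();
    net.train(data, 10000);
    console.log(`training time: ${performance.now() - start}ms`);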
diff --git a/bench/netsaur_wasm.ts b/bench/netsaur_wasm.ts
index b8b39ab..3762696 100644
--- a/bench/netsaur_wasm.ts
+++ b/bench/netsaur_wasm.ts
@@ -11,72 +11,34 @@ import {
 
 await setupBackend(WASM);
 
-Deno.bench(
-  { name: "xor 10000 epochs", permissions: "inherit" },
-  async () => {
-    const net = new Sequential({
-      size: [4, 2],
-      silent: true,
-      layers: [
-        DenseLayer({ size: [3] }),
-        SigmoidLayer(),
-        DenseLayer({ size: [1] }),
-        SigmoidLayer(),
-      ],
-      cost: Cost.MSE,
-    });
-
-    net.train(
-      [
-        {
-          inputs: tensor2D([
-            [0, 0],
-            [1, 0],
-            [0, 1],
-            [1, 1],
-          ]),
-          outputs: tensor2D([[0], [1], [1], [0]]),
-        },
-      ],
-      10000,
-    );
-
-    console.log((await net.predict(tensor1D([0, 0]))).data);
-    console.log((await net.predict(tensor1D([1, 0]))).data);
-    console.log((await net.predict(tensor1D([0, 1]))).data);
-    console.log((await net.predict(tensor1D([1, 1]))).data);
-  },
+const net = new Sequential({
+  size: [4, 2],
+  silent: true,
+  layers: [
+    DenseLayer({ size: [3] }),
+    SigmoidLayer(),
+    DenseLayer({ size: [1] }),
+    SigmoidLayer(),
+  ],
+  cost: Cost.MSE,
+});
+
+net.train(
+  [
+    {
+      inputs: tensor2D([
+        [0, 0],
+        [1, 0],
+        [0, 1],
+        [1, 1],
+      ]),
+      outputs: tensor2D([[0], [1], [1], [0]]),
+    },
+  ],
+  10000,
 );
-
-// const net = new NeuralNetwork({
-//   size: [4, 2],
-//   silent: true,
-//   layers: [
-//     DenseLayer({ size: [3], activation: Activation.Sigmoid }),
-//     DenseLayer({ size: [1], activation: Activation.Sigmoid }),
-//   ],
-//   cost: Cost.MSE,
-// });
-
-// const time = performance.now();
-
-// net.train(
-//   [
-//     {
-//       inputs: tensor2D([
-//         [0, 0],
-//         [1, 0],
-//         [0, 1],
-//         [1, 1],
-//       ]),
-//       outputs: tensor2D([[0], [1], [1], [0]]),
-//     },
-//   ],
-//   10000,
-// )
-
-// console.log(`training time: ${performance.now() - time}ms`);
-// console.log((await net.predict(tensor2D([[0, 0]]))).data);
-// console.log((await net.predict(tensor2D([[1, 0]]))).data);
-// console.log((await net.predict(tensor2D([[0, 1]]))).data);
-// console.log((await net.predict(tensor2D([[1, 1]]))).data);
+
+console.log((await net.predict(tensor1D([0, 0]))).data);
+console.log((await net.predict(tensor1D([1, 0]))).data);
+console.log((await net.predict(tensor1D([0, 1]))).data);
+console.log((await net.predict(tensor1D([1, 1]))).data);
diff --git a/bench/node/brain_cpu.js b/bench/node/brain_cpu.js
deleted file mode 100644
index 381bc96..0000000
--- a/bench/node/brain_cpu.js
+++ /dev/null
@@ -1,27 +0,0 @@
-const brain = require("brain.js");
-
-const time = performance.now();
-
-const config = {
-  binaryThresh: 0.5,
-  hiddenLayers: [4],
-  activation: "sigmoid",
-  leakyReluAlpha: 0.01,
-};
-
-const net = new brain.NeuralNetwork(config);
-
-net.train([
-  { input: [0, 0], output: [0] },
-  { input: [1, 0], output: [1] },
-  { input: [0, 1], output: [1] },
-  { input: [1, 1], output: [0] },
-], {
-  iterations: 10000,
-});
-
-console.log(net.run([0, 0]));
-console.log(net.run([1, 0]));
-console.log(net.run([0, 1]));
-console.log(net.run([1, 1]));
-console.log(`time: ${performance.now() - time}ms`);
diff --git a/bench/node/brain_gpu.js b/bench/node/brain_gpu.js
deleted file mode 100644
index 4f4f542..0000000
--- a/bench/node/brain_gpu.js
+++ /dev/null
@@ -1,26 +0,0 @@
-const brain = require("brain.js");
-
-const config = {
-  binaryThresh: 0.5,
-  hiddenLayers: [4],
-  activation: "sigmoid",
-  leakyReluAlpha: 0.01,
-};
-
-const net = new brain.NeuralNetworkGPU(config);
-const time = performance.now();
-
-net.train([
-  { input: [0, 0], output: [0] },
-  { input: [1, 0], output: [1] },
-  { input: [0, 1], output: [1] },
-  { input: [1, 1], output: [0] },
-], {
-  iterations: 10000,
-});
-
-console.log(`training time: ${performance.now() - time}ms`);
-console.log(net.run([0, 0]));
-console.log(net.run([1, 0]));
-console.log(net.run([0, 1]));
-console.log(net.run([1, 1]));
diff --git a/bench/node/package.json b/bench/node/package.json
deleted file mode 100644
index 6799ccf..0000000
--- a/bench/node/package.json
+++ /dev/null
@@ -1,18 +0,0 @@
-{
-  "name": "bench",
-  "version": "1.0.0",
-  "description": "",
-  "main": "index.js",
-  "scripts": {
-    "test": "echo \"Error: no test specified\" && exit 1"
-  },
-  "author": "",
-  "license": "ISC",
-  "dependencies": {
-    "@tensorflow/tfjs": "^4.4.0",
-    "@tensorflow/tfjs-node": "^4.4.0",
-    "@tensorflow/tfjs-node-gpu": "^4.4.0",
-    "brain.js": "^2.0.0-beta.23",
-    "tfjs": "^0.6.0"
-  }
-}
diff --git a/bench/node/tfjs_cpu.js b/bench/node/tfjs_cpu.js
deleted file mode 100644
index 930b9c0..0000000
--- a/bench/node/tfjs_cpu.js
+++ /dev/null
@@ -1,23 +0,0 @@
-const tf = require("@tensorflow/tfjs-node");
-
-async function predictOutput() {
-  const time = performance.now();
-  const model = tf.sequential();
-  model.add(tf.layers.dense({ units: 8, inputShape: 2, activation: "tanh" }));
-  model.add(tf.layers.dense({ units: 1, activation: "sigmoid" }));
-  model.compile({ optimizer: "sgd", loss: "meanSquaredError", lr: 0.6 });
-
-  // Creating dataset
-  const xs = tf.tensor2d([[0, 0], [0, 1], [1, 0], [1, 1]]);
-  const ys = tf.tensor2d([[0], [1], [1], [0]]);
-
-  // Train the model
-  await model.fit(xs, ys, {
-    batchSize: 1,
-    epochs: 10000,
-    verbose: false,
-  });
-  console.log(`training time: ${performance.now() - time}ms`);
-}
-
-predictOutput();
diff --git a/crates/core-gpu/Cargo.toml b/crates/core-gpu/Cargo.toml
index 028a8c3..581cce6 100644
--- a/crates/core-gpu/Cargo.toml
+++ b/crates/core-gpu/Cargo.toml
@@ -6,12 +6,11 @@ version = { workspace = true }
 [lib]
 crate-type = ["cdylib"]
 
-
 [dependencies]
-ndarray = "0.15.6"
-ndarray-rand = "0.14.0"
-serde = {version = "1.0", features = ["derive"]}
-serde_json = "1.0"
-safetensors = "0.4.0"
-cudarc = "0.9.14"
-thiserror = "1.0.49"
+cudarc = { workspace = true }
+ndarray = { workspace = true }
+ndarray-rand = { workspace = true }
+safetensors = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+thiserror = { workspace = true }
diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml
index e5f7cc1..f28d4b8 100644
--- a/crates/core/Cargo.toml
+++ b/crates/core/Cargo.toml
@@ -7,11 +7,11 @@ version = { workspace = true }
 crate-type = ["cdylib"]
 
 [dependencies]
-ndarray = "0.15.6"
-ndarray-rand = "0.14.0"
-serde = { version = "1.0", features = ["derive"] }
-serde_json = "1.0"
-safetensors = "0.4.0"
+ndarray = { workspace = true }
+ndarray-rand = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+safetensors = { workspace = true }
 
 [target.'cfg(target_arch = "wasm32")'.dependencies]
 wasm-bindgen = "0.2.92"
diff --git a/crates/core/src/cpu/layers/flatten.rs b/crates/core/src/cpu/layers/flatten.rs
index 63679ba..e1ee285 100644
--- a/crates/core/src/cpu/layers/flatten.rs
+++ b/crates/core/src/cpu/layers/flatten.rs
@@ -34,10 +34,10 @@ impl FlattenCPULayer {
 
     pub fn forward_propagate(&mut self, inputs: ArrayD<f32>) -> ArrayD<f32> {
         let output_size = IxDyn(&self.output_size);
-        inputs.into_shape(output_size).unwrap()
+        inputs.into_shape_with_order(output_size).unwrap()
     }
 
     pub fn backward_propagate(&mut self, d_outputs: ArrayD<f32>) -> ArrayD<f32> {
-        d_outputs.into_shape(self.input_size.clone()).unwrap()
+        d_outputs.into_shape_with_order(self.input_size.clone()).unwrap()
     }
 }
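Note: the flatten.rs change follows from the ndarray 0.16 bump earlier in this patch. ndarray 0.16 deprecates `into_shape` in favour of `into_shape_with_order`, which fixes the traversal order of the reshape explicitly (row-major/C order when only a shape is passed) instead of inheriting whatever layout the array happens to have. For intuition, the same row-major reshape sketched in TypeScript with plain arrays (not the ndarray API):

    // Reshape a flat 6-element buffer into shape [3, 2] in C order:
    // element order is preserved, the buffer is just re-chunked.
    const flat = [1, 2, 3, 4, 5, 6];
    const rows = 3, cols = 2;
    const reshaped = Array.from(
      { length: rows },
      (_, r) => flat.slice(r * cols, (r + 1) * cols),
    );
    console.log(reshaped); // [[1, 2], [3, 4], [5, 6]]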
diff --git a/crates/tokenizers/Cargo.toml b/crates/tokenizers/Cargo.toml
index 2538ca6..9693430 100644
--- a/crates/tokenizers/Cargo.toml
+++ b/crates/tokenizers/Cargo.toml
@@ -7,10 +7,10 @@ version = { workspace = true }
 crate-type = ["cdylib"]
 
 [dependencies]
-ndarray = "0.15.6"
-ndarray-rand = "0.14.0"
-serde = {version = "1.0", features = ["derive"]}
-serde_json = "1.0"
+ndarray = { workspace = true }
+ndarray-rand = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
 serde-wasm-bindgen = "0.6.0"
 tokenizers = { version="0.20.0", default-features=false, features = ["unstable_wasm"]}
 wasm-bindgen = "0.2.92"
diff --git a/deno.jsonc b/deno.jsonc
index 1f86198..5fd533e 100644
--- a/deno.jsonc
+++ b/deno.jsonc
@@ -29,6 +29,7 @@
     "./visualizer": "./packages/visualizer/mod.ts"
   },
   "tasks": {
+    // Examples
     "example:xor": "deno -A ./examples/xor_auto.ts",
     "example:xor-option": "deno -A ./examples/xor_option.ts",
     "example:xor-cpu": "deno -A ./examples/xor_cpu.ts",
@@ -48,10 +49,18 @@
     "example:mnist-train": "deno -A examples/mnist/train.ts",
     "example:mnist-predict": "deno -A examples/mnist/predict.ts",
     "example:tokenizers-basic": "deno -A examples/tokenizers/basic.ts",
+
+    // Benchmarks
+    "bench:netsaur-cpu": "deno -A bench/netsaur_cpu.ts",
+    "bench:netsaur-wasm": "deno -A bench/netsaur_wasm.ts",
+    "bench:netsaur": "deno run bench:netsaur-cpu && deno run bench:netsaur-wasm",
+    "bench:torch-cpu": "python bench/torch_cpu.py",
+
+    // Build
     "build": "deno run build:cpu && deno run build:wasm && deno run build:tokenizers",
     "build:cpu": "cargo build --release -p netsaur",
     "build:gpu": "cargo build --release -p netsaur-gpu",
-    "build:wasm": "deno -Ar jsr:@deno/wasmbuild@0.17.2 -p netsaur --out src/backends/wasm/lib",
-    "build:tokenizers": "deno -Ar jsr:@deno/wasmbuild@0.17.2 -p netsaur-tokenizers --out tokenizers/lib"
+    "build:wasm": "deno -Ar jsr:@deno/wasmbuild@0.17.2 -p netsaur --out packages/core/src/backends/wasm/lib",
+    "build:tokenizers": "deno -Ar jsr:@deno/wasmbuild@0.17.2 -p netsaur-tokenizers --out packages/tokenizers/lib"
   }
-}
\ No newline at end of file
+}
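Note: with the tasks above registered, the benchmarks run through the Deno task runner:

    deno task bench:netsaur-cpu
    deno task bench:netsaur

The chained bench:netsaur entry reuses the same `deno run <task>` convention the existing build task already relies on.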
diff --git a/deno.lock b/deno.lock
new file mode 100644
index 0000000..3c921db
--- /dev/null
+++ b/deno.lock
@@ -0,0 +1,45 @@
+{
+  "version": "4",
+  "specifiers": {
+    "jsr:@denosaurs/plug@1.0.3": "1.0.3",
+    "jsr:@std/assert@~0.213.1": "0.213.1",
+    "jsr:@std/encoding@0.213.1": "0.213.1",
+    "jsr:@std/fmt@0.213.1": "0.213.1",
+    "jsr:@std/fs@0.213.1": "0.213.1",
+    "jsr:@std/path@0.213.1": "0.213.1",
+    "jsr:@std/path@~0.213.1": "0.213.1"
+  },
+  "jsr": {
+    "@denosaurs/plug@1.0.3": {
+      "integrity": "b010544e386bea0ff3a1d05e0c88f704ea28cbd4d753439c2f1ee021a85d4640",
+      "dependencies": [
+        "jsr:@std/encoding",
+        "jsr:@std/fmt",
+        "jsr:@std/fs",
+        "jsr:@std/path@0.213.1"
+      ]
+    },
+    "@std/assert@0.213.1": {
+      "integrity": "24c28178b30c8e0782c18e8e94ea72b16282207569cdd10ffb9d1d26f2edebfe"
+    },
+    "@std/encoding@0.213.1": {
+      "integrity": "fcbb6928713dde941a18ca5db88ca1544d0755ec8fb20fe61e2dc8144b390c62"
+    },
+    "@std/fmt@0.213.1": {
+      "integrity": "a06d31777566d874b9c856c10244ac3e6b660bdec4c82506cd46be052a1082c3"
+    },
+    "@std/fs@0.213.1": {
+      "integrity": "fbcaf099f8a85c27ab0712b666262cda8fe6d02e9937bf9313ecaea39a22c501",
+      "dependencies": [
+        "jsr:@std/assert",
+        "jsr:@std/path@~0.213.1"
+      ]
+    },
+    "@std/path@0.213.1": {
+      "integrity": "f187bf278a172752e02fcbacf6bd78a335ed320d080a7ed3a5a59c3e88abc673",
+      "dependencies": [
+        "jsr:@std/assert"
+      ]
+    }
+  }
+}
diff --git a/examples/mnist/download.ts b/examples/mnist/download.ts
index 61ed2d8..0063249 100644
--- a/examples/mnist/download.ts
+++ b/examples/mnist/download.ts
@@ -12,18 +12,18 @@ async function download(url: string, to: string) {
 }
 
 await download(
-  "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
+  "train-images-idx3-ubyte.gz",
   "train-images.idx",
 );
 await download(
-  "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
+  "./train-labels-idx1-ubyte.gz",
   "train-labels.idx",
 );
 await download(
-  "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
+  "./t10k-images-idx3-ubyte.gz",
   "test-images.idx",
 );
 await download(
-  "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
+  "./t10k-labels-idx1-ubyte.gz",
   "test-labels.idx",
 );
diff --git a/examples/mnist/train.ts b/examples/mnist/train.ts
index 61013c8..598ff17 100644
--- a/examples/mnist/train.ts
+++ b/examples/mnist/train.ts
@@ -42,11 +42,11 @@ const trainSet = loadDataset(
   "train-images.idx",
   "train-labels.idx",
   0,
-  5000,
+  10000,
   32,
 );
 
-const epochs = 1;
+const epochs = 10;
 console.log("Training (" + epochs + " epochs)...");
 const start = performance.now();
 network.train(trainSet, epochs, 1, 0.005);
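Note: the train.ts bump is roughly a 20x increase in training work, twice the data window and ten times the epochs. A quick sanity check on the optimizer step count, assuming the 32 passed to loadDataset is the batch size:

    const samples = 10000, batchSize = 32, epochs = 10;
    console.log(Math.ceil(samples / batchSize) * epochs); // 3130 steps
    // before this change: Math.ceil(5000 / 32) * 1 = 157 steps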
"jsr:@std/assert@~0.213.1": "0.213.1", + "jsr:@std/encoding@0.213.1": "0.213.1", + "jsr:@std/fmt@0.213.1": "0.213.1", + "jsr:@std/fs@0.213.1": "0.213.1", + "jsr:@std/path@0.213.1": "0.213.1", + "jsr:@std/path@~0.213.1": "0.213.1" + }, + "jsr": { + "@denosaurs/plug@1.0.3": { + "integrity": "b010544e386bea0ff3a1d05e0c88f704ea28cbd4d753439c2f1ee021a85d4640", + "dependencies": [ + "jsr:@std/encoding", + "jsr:@std/fmt", + "jsr:@std/fs", + "jsr:@std/path@0.213.1" + ] + }, + "@std/assert@0.213.1": { + "integrity": "24c28178b30c8e0782c18e8e94ea72b16282207569cdd10ffb9d1d26f2edebfe" + }, + "@std/encoding@0.213.1": { + "integrity": "fcbb6928713dde941a18ca5db88ca1544d0755ec8fb20fe61e2dc8144b390c62" + }, + "@std/fmt@0.213.1": { + "integrity": "a06d31777566d874b9c856c10244ac3e6b660bdec4c82506cd46be052a1082c3" + }, + "@std/fs@0.213.1": { + "integrity": "fbcaf099f8a85c27ab0712b666262cda8fe6d02e9937bf9313ecaea39a22c501", + "dependencies": [ + "jsr:@std/assert", + "jsr:@std/path@~0.213.1" + ] + }, + "@std/path@0.213.1": { + "integrity": "f187bf278a172752e02fcbacf6bd78a335ed320d080a7ed3a5a59c3e88abc673", + "dependencies": [ + "jsr:@std/assert" + ] + } + } +} diff --git a/examples/mnist/download.ts b/examples/mnist/download.ts index 61ed2d8..0063249 100644 --- a/examples/mnist/download.ts +++ b/examples/mnist/download.ts @@ -12,18 +12,18 @@ async function download(url: string, to: string) { } await download( - "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", + "train-images-idx3-ubyte.gz", "train-images.idx", ); await download( - "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", + "./train-labels-idx1-ubyte.gz", "train-labels.idx", ); await download( - "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", + "./t10k-images-idx3-ubyte.gz", "test-images.idx", ); await download( - "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", + "./t10k-labels-idx1-ubyte.gz", "test-labels.idx", ); diff --git a/examples/mnist/train.ts b/examples/mnist/train.ts index 61013c8..598ff17 100644 --- a/examples/mnist/train.ts +++ b/examples/mnist/train.ts @@ -42,11 +42,11 @@ const trainSet = loadDataset( "train-images.idx", "train-labels.idx", 0, - 5000, + 10000, 32, ); -const epochs = 1; +const epochs = 10; console.log("Training (" + epochs + " epochs)..."); const start = performance.now(); network.train(trainSet, epochs, 1, 0.005); diff --git a/packages/core/src/backends/wasm/lib/netsaur.generated.js b/packages/core/src/backends/wasm/lib/netsaur.generated.js index 05940c6..47661c0 100644 --- a/packages/core/src/backends/wasm/lib/netsaur.generated.js +++ b/packages/core/src/backends/wasm/lib/netsaur.generated.js @@ -4,7 +4,7 @@ // deno-fmt-ignore-file /// -// source-hash: 24bcfd72e3631f0f05524ea722a11a2b75b29185 +// source-hash: c1eff57085f8488444a8499d3d2fcad1650a7099 let wasm; let cachedInt32Memory0; @@ -12,21 +12,12 @@ const heap = new Array(128).fill(undefined); heap.push(undefined, null, true, false); -let heap_next = heap.length; - -function addHeapObject(obj) { - if (heap_next === heap.length) heap.push(heap.length + 1); - const idx = heap_next; - heap_next = heap[idx]; - - heap[idx] = obj; - return idx; -} - function getObject(idx) { return heap[idx]; } +let heap_next = heap.length; + function dropObject(idx) { if (idx < 132) return; heap[idx] = heap_next; @@ -39,6 +30,15 @@ function takeObject(idx) { return ret; } +function addHeapObject(obj) { + if (heap_next === heap.length) heap.push(heap.length + 1); + const idx = heap_next; + heap_next = heap[idx]; + + heap[idx] = obj; + 
diff --git a/packages/utilities/src/text/preprocess/tokenize/split.ts b/packages/utilities/src/text/preprocess/tokenize/split.ts
index 046d370..ba02928 100644
--- a/packages/utilities/src/text/preprocess/tokenize/split.ts
+++ b/packages/utilities/src/text/preprocess/tokenize/split.ts
@@ -77,12 +77,11 @@ export class SplitTokenizer {
       let i = 0;
       while (i < text.length) {
         res[i] = this.#transform(text[i], size);
-        i += 1;
+        i++;
       }
       return res;
-    } else {
-      return [this.#transform(text, 0)];
     }
+    return [this.#transform(text, 0)];
   }
   #transform(text: string, size: number): number[] {
     const words = this.split(text);
@@ -93,15 +92,13 @@
     while (i < words.length && i < size) {
       if (this.vocabulary.has(words[i])) {
         const index = this.vocabulary.get(words[i]);
-        if (typeof index === "number") {
-          res[i] = index;
-        } else {
-          res[i] = this.vocabulary.get("__unk__") || 0;
-        }
+        res[i] = typeof index === "number"
+          ? index
+          : this.vocabulary.get("__unk__") || 0;
       } else {
         res[i] = this.vocabulary.get("__unk__") || 0;
       }
-      i += 1;
+      i++;
     }
     return res;
   }
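Note: the SplitTokenizer cleanup is behaviour-preserving. A word missing from the vocabulary, or mapped to a non-numeric entry, still falls back to the "__unk__" id, and to 0 if even that is absent. A standalone sketch of the lookup, using a made-up Map vocabulary rather than the class's real fields:

    const vocabulary = new Map<string, number>([
      ["__unk__", 1],
      ["hello", 2],
    ]);

    function tokenId(word: string): number {
      const index = vocabulary.get(word);
      return typeof index === "number"
        ? index
        : vocabulary.get("__unk__") || 0;
    }

    console.log(tokenId("hello"));  // 2
    console.log(tokenId("unseen")); // 1, the "__unk__" id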