-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathpredict.ts
45 lines (39 loc) · 1.05 KB
/
predict.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import {
CPU,
type Rank,
Sequential,
setupBackend,
type Shape,
type Tensor,
tensor,
} from "../../packages/core/mod.ts";
import { loadDataset } from "./common.ts";
// Initialise the CPU backend before any tensors are created or the model is loaded.
await setupBackend(CPU);
// Load the serialized network to evaluate.
// NOTE(review): this path is resolved relative to the current working
// directory, while the dataset files below use bare names — confirm both
// resolve correctly from wherever this script is launched.
const network = Sequential.loadFile("examples/mnist/mnist.test.st");
// Load the labelled test samples from the image/label files.
// NOTE(review): the meaning of the trailing (0, 1000) arguments is defined in
// common.ts — presumably a start offset and a sample count/end; confirm there.
const testSet = loadDataset("test-images.idx", "test-labels.idx", 0, 1000);
// Give every input the shape [1, 28, 28] (presumably channel, height, width —
// confirm against the trained model). The evaluation loop later prepends a
// batch dimension of 1 to this shape. A plain loop is used because this is a
// pure in-place mutation; `map` would build and discard a new array.
for (const sample of testSet) {
  sample.inputs.shape = [1, 28, 28];
}
/**
 * Returns the index of the largest entry in `mat.data`.
 *
 * The scan runs left to right and only moves to a new position when its value
 * is strictly greater than the best seen so far, so ties resolve to the
 * earliest index. The running best starts at -Infinity, so NaN and -Infinity
 * entries are never chosen. An empty tensor, or one holding only such values,
 * yields -1.
 */
function argmax(mat: Tensor<Rank>) {
  const values = mat.data;
  let bestIndex = -1;
  let bestValue = -Infinity;
  for (let pos = 0; pos < values.length; pos += 1) {
    const value = values[pos];
    // Skip anything not strictly larger (this also skips NaN).
    if (!(value > bestValue)) {
      continue;
    }
    bestValue = value;
    bestIndex = pos;
  }
  return bestIndex;
}
// Evaluate the network on every test sample. A prediction counts as correct
// when the index of the network's largest output equals the index of the
// largest value in the sample's expected output.
let correct = 0;
for (const test of testSet) {
  // Prepend a batch dimension of 1 to the sample's own shape.
  const batchedInput = tensor(
    test.inputs.data,
    [1, ...test.inputs.shape] as Shape<Rank>,
  );
  const prediction = argmax(await network.predict(batchedInput));
  // NOTE(review): the cast mirrors the original; if loadDataset already types
  // `outputs` as Tensor<Rank> it can be dropped — confirm in common.ts.
  const expected = argmax(test.outputs as Tensor<Rank>);
  if (expected === prediction) {
    correct += 1;
  }
}
console.log(`${correct} / ${testSet.length} correct`);
// An empty test set would otherwise divide by zero and print "NaN%".
if (testSet.length === 0) {
  console.log("accuracy: n/a (no test samples)");
} else {
  console.log(`accuracy: ${((correct / testSet.length) * 100).toFixed(2)}%`);
}