Skip to content

Commit

Permalink
add binary classification example
Browse files Browse the repository at this point in the history
  • Loading branch information
retraigo committed Oct 21, 2023
1 parent ce0f132 commit 1fb9e7f
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 0 deletions.
1 change: 1 addition & 0 deletions deno.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"example:xor-gpu": "deno run -A --unstable ./examples/xor_gpu.ts",
"example:xor-wasm": "deno run -A --unstable ./examples/xor_wasm.ts",
"example:linear": "deno run -A --unstable ./examples/linear.ts",
"example:binary": "deno run -A --unstable ./examples/classification/binary_iris.ts",
"example:filters": "deno run -A --unstable examples/filters/conv.ts ",
"example:train": "deno run -A --unstable examples/model/train.ts ",
"example:run": "deno run -A --unstable examples/model/run.ts ",
Expand Down
3 changes: 3 additions & 0 deletions examples/classification/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Binary Classification
This example showcases binary classification on the Iris dataset.
The `Iris Virginica` class is omitted for this example.
100 changes: 100 additions & 0 deletions examples/classification/binary_iris.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
5.1,3.5,1.4,.2,"Setosa"
4.9,3,1.4,.2,"Setosa"
4.7,3.2,1.3,.2,"Setosa"
4.6,3.1,1.5,.2,"Setosa"
5,3.6,1.4,.2,"Setosa"
5.4,3.9,1.7,.4,"Setosa"
4.6,3.4,1.4,.3,"Setosa"
5,3.4,1.5,.2,"Setosa"
4.4,2.9,1.4,.2,"Setosa"
4.9,3.1,1.5,.1,"Setosa"
5.4,3.7,1.5,.2,"Setosa"
4.8,3.4,1.6,.2,"Setosa"
4.8,3,1.4,.1,"Setosa"
4.3,3,1.1,.1,"Setosa"
5.8,4,1.2,.2,"Setosa"
5.7,4.4,1.5,.4,"Setosa"
5.4,3.9,1.3,.4,"Setosa"
5.1,3.5,1.4,.3,"Setosa"
5.7,3.8,1.7,.3,"Setosa"
5.1,3.8,1.5,.3,"Setosa"
5.4,3.4,1.7,.2,"Setosa"
5.1,3.7,1.5,.4,"Setosa"
4.6,3.6,1,.2,"Setosa"
5.1,3.3,1.7,.5,"Setosa"
4.8,3.4,1.9,.2,"Setosa"
5,3,1.6,.2,"Setosa"
5,3.4,1.6,.4,"Setosa"
5.2,3.5,1.5,.2,"Setosa"
5.2,3.4,1.4,.2,"Setosa"
4.7,3.2,1.6,.2,"Setosa"
4.8,3.1,1.6,.2,"Setosa"
5.4,3.4,1.5,.4,"Setosa"
5.2,4.1,1.5,.1,"Setosa"
5.5,4.2,1.4,.2,"Setosa"
4.9,3.1,1.5,.2,"Setosa"
5,3.2,1.2,.2,"Setosa"
5.5,3.5,1.3,.2,"Setosa"
4.9,3.6,1.4,.1,"Setosa"
4.4,3,1.3,.2,"Setosa"
5.1,3.4,1.5,.2,"Setosa"
5,3.5,1.3,.3,"Setosa"
4.5,2.3,1.3,.3,"Setosa"
4.4,3.2,1.3,.2,"Setosa"
5,3.5,1.6,.6,"Setosa"
5.1,3.8,1.9,.4,"Setosa"
4.8,3,1.4,.3,"Setosa"
5.1,3.8,1.6,.2,"Setosa"
4.6,3.2,1.4,.2,"Setosa"
5.3,3.7,1.5,.2,"Setosa"
5,3.3,1.4,.2,"Setosa"
7,3.2,4.7,1.4,"Versicolor"
6.4,3.2,4.5,1.5,"Versicolor"
6.9,3.1,4.9,1.5,"Versicolor"
5.5,2.3,4,1.3,"Versicolor"
6.5,2.8,4.6,1.5,"Versicolor"
5.7,2.8,4.5,1.3,"Versicolor"
6.3,3.3,4.7,1.6,"Versicolor"
4.9,2.4,3.3,1,"Versicolor"
6.6,2.9,4.6,1.3,"Versicolor"
5.2,2.7,3.9,1.4,"Versicolor"
5,2,3.5,1,"Versicolor"
5.9,3,4.2,1.5,"Versicolor"
6,2.2,4,1,"Versicolor"
6.1,2.9,4.7,1.4,"Versicolor"
5.6,2.9,3.6,1.3,"Versicolor"
6.7,3.1,4.4,1.4,"Versicolor"
5.6,3,4.5,1.5,"Versicolor"
5.8,2.7,4.1,1,"Versicolor"
6.2,2.2,4.5,1.5,"Versicolor"
5.6,2.5,3.9,1.1,"Versicolor"
5.9,3.2,4.8,1.8,"Versicolor"
6.1,2.8,4,1.3,"Versicolor"
6.3,2.5,4.9,1.5,"Versicolor"
6.1,2.8,4.7,1.2,"Versicolor"
6.4,2.9,4.3,1.3,"Versicolor"
6.6,3,4.4,1.4,"Versicolor"
6.8,2.8,4.8,1.4,"Versicolor"
6.7,3,5,1.7,"Versicolor"
6,2.9,4.5,1.5,"Versicolor"
5.7,2.6,3.5,1,"Versicolor"
5.5,2.4,3.8,1.1,"Versicolor"
5.5,2.4,3.7,1,"Versicolor"
5.8,2.7,3.9,1.2,"Versicolor"
6,2.7,5.1,1.6,"Versicolor"
5.4,3,4.5,1.5,"Versicolor"
6,3.4,4.5,1.6,"Versicolor"
6.7,3.1,4.7,1.5,"Versicolor"
6.3,2.3,4.4,1.3,"Versicolor"
5.6,3,4.1,1.3,"Versicolor"
5.5,2.5,4,1.3,"Versicolor"
5.5,2.6,4.4,1.2,"Versicolor"
6.1,3,4.6,1.4,"Versicolor"
5.8,2.6,4,1.2,"Versicolor"
5,2.3,3.3,1,"Versicolor"
5.6,2.7,4.2,1.3,"Versicolor"
5.7,3,4.2,1.2,"Versicolor"
5.7,2.9,4.2,1.3,"Versicolor"
6.2,2.9,4.3,1.3,"Versicolor"
5.1,2.5,3,1.1,"Versicolor"
5.7,2.8,4.1,1.3,"Versicolor"
101 changes: 101 additions & 0 deletions examples/classification/binary_iris.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import {
Cost,
CPU,
DenseLayer,
Sequential,
setupBackend,
SigmoidLayer,
tensor1D,
tensor2D,
} from "../../mod.ts";

import { parse } from "https://deno.land/[email protected]/csv/parse.ts";

// Import helpers for metrics
import {
accuracyScore,
// Metrics
ConfusionMatrix,
precisionScore,
sensitivityScore,
specificityScore,
// Split the dataset
useSplit,
} from "https://deno.land/x/[email protected]/mod.ts";

// Define classes
const classes = ["Setosa", "Versicolor"];

// Read the training dataset
const _data = Deno.readTextFileSync("examples/classification/binary_iris.csv");
const data = parse(_data);

// Get the predictors (x) and targets (y)
const x = data.map((fl) => fl.slice(0, 4).map(Number));
const y = data.map((fl) => classes.indexOf(fl[4]));

// Split the dataset for training and testing
const [train, test] = useSplit({ ratio: [7, 3], shuffle: true }, x, y) as [
[typeof x, typeof y],
[typeof x, typeof y],
];

// Setup the CPU backend for Netsaur
await setupBackend(CPU);

// Create a sequential neural network
const net = new Sequential({
// Set number of minibatches to 4
// Set size of output to 4
size: [4, 4],

// Disable logging during training
silent: false,

// Define each layer of the network
layers: [
// A dense layer with 4 neurons
DenseLayer({ size: [4] }),
// A sigmoid activation layer
SigmoidLayer(),
// A dense layer with 1 neuron
DenseLayer({ size: [1] }),
// Another sigmoid layer
SigmoidLayer(),
],
// We are using MSE for finding cost
cost: Cost.MSE,
});

const time = performance.now();

// Train the network
net.train(
[
{
inputs: tensor2D(train[0]),
outputs: tensor2D(train[1].map((x) => [x])),
},
],
// Train for 10000 epochs
10000,
);

console.log(`training time: ${performance.now() - time}ms`);

// Calculate metrics
let [tp, fn, fp, tn] = [0, 0, 0, 0];
for (let i = 0; i < test[0].length; i += 1) {
const res = (await net.predict(tensor1D(test[0][i]))).data[0] < 0.5 ? 0 : 1;
if (res === 1 && test[1][i] == 1) tp += 1;
if (res === 0 && test[1][i] == 1) fn += 1;
if (res === 1 && test[1][i] == 0) fp += 1;
if (res === 0 && test[1][i] == 0) tn += 1;
}

const cMatrix = new ConfusionMatrix([tp, fn, fp, tn]);
console.log("Confusion Matrix: ", cMatrix);
console.log("Accuracy: ", `${accuracyScore(cMatrix) * 100}%`);
console.log("Precision: ", `${precisionScore(cMatrix) * 100}%`);
console.log("Sensitivity / Recall: ", `${sensitivityScore(cMatrix) * 100}%`);
console.log("Specificity: ", `${specificityScore(cMatrix) * 100}%`);

0 comments on commit 1fb9e7f

Please sign in to comment.