Skip to content

Commit

Permalink
Merge pull request #38 from Samyssmile/feature/unified-classifier
Browse files Browse the repository at this point in the history
Feature/unified classifier
  • Loading branch information
Samyssmile authored Oct 10, 2023
2 parents a2afa7f + 4cbc074 commit f0c0ae8
Show file tree
Hide file tree
Showing 32 changed files with 1,175 additions and 641 deletions.
143 changes: 143 additions & 0 deletions example/src/main/java/de/example/benchmark/Benchmark.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
package de.example.benchmark;

import de.edux.api.Classifier;
import de.edux.functions.activation.ActivationFunction;
import de.edux.functions.initialization.Initialization;
import de.edux.functions.loss.LossFunction;
import de.edux.ml.decisiontree.DecisionTree;
import de.edux.ml.knn.KnnClassifier;
import de.edux.ml.nn.config.NetworkConfiguration;
import de.edux.ml.nn.network.MultilayerPerceptron;
import de.edux.ml.randomforest.RandomForest;
import de.edux.ml.svm.SVMKernel;
import de.edux.ml.svm.SupportVectorMachine;
import de.example.data.seaborn.Penguin;
import de.example.data.seaborn.SeabornDataProcessor;

import java.util.ArrayList;
import java.util.Map;
import java.io.File;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;

/**
 * Compares the performance of several classifiers (KNN, decision tree, random
 * forest, SVM and a multilayer perceptron) on the Seaborn penguins dataset.
 * Each classifier is trained and evaluated over a fixed number of rounds on a
 * fresh train/test split, then the average, best and worst accuracies are
 * printed to stdout.
 */
public class Benchmark {
    private static final boolean SHUFFLE = true;
    private static final boolean NORMALIZE = true;
    private static final boolean FILTER_INCOMPLETE_RECORDS = true;
    private static final double TRAIN_TEST_SPLIT_RATIO = 0.75;
    // Number of independent train/evaluate rounds per classifier.
    private static final int BENCHMARK_ROUNDS = 50;
    private static final File CSV_FILE =
            new File("example" + File.separator + "datasets" + File.separator
                    + "seaborn-penguins" + File.separator + "penguins.csv");

    private double[][] trainFeatures;
    private double[][] trainLabels;
    private double[][] testFeatures;
    private double[][] testLabels;
    private MultilayerPerceptron multilayerPerceptron;
    private NetworkConfiguration networkConfiguration;

    public static void main(String[] args) {
        new Benchmark().run();
    }

    /** Runs the full benchmark: build classifiers, train/evaluate in rounds, report. */
    private void run() {
        initFeaturesAndLabels();

        Classifier knn = new KnnClassifier(2);
        Classifier decisionTree = new DecisionTree(8, 2, 1, 3);
        Classifier randomForest = new RandomForest(100, 10, 2, 1, 3, 60);
        Classifier svm = new SupportVectorMachine(SVMKernel.LINEAR, 1);

        networkConfiguration = new NetworkConfiguration(trainFeatures[0].length,
                List.of(128, 256, 512), 3, 0.01, 300,
                ActivationFunction.LEAKY_RELU, ActivationFunction.SOFTMAX,
                LossFunction.CATEGORICAL_CROSS_ENTROPY,
                Initialization.XAVIER, Initialization.XAVIER);
        // NOTE(review): the MLP is constructed with the *test* split — confirm this
        // matches the MultilayerPerceptron constructor's expectations.
        multilayerPerceptron = new MultilayerPerceptron(networkConfiguration, testFeatures, testLabels);

        Map<String, List<Double>> results = new ConcurrentHashMap<>();
        results.put("KNN", new ArrayList<>());
        results.put("DecisionTree", new ArrayList<>());
        results.put("RandomForest", new ArrayList<>());
        results.put("SVM", new ArrayList<>());
        results.put("MLP", new ArrayList<>());

        IntStream.range(0, BENCHMARK_ROUNDS).forEach(i -> {
            knn.train(trainFeatures, trainLabels);
            decisionTree.train(trainFeatures, trainLabels);
            randomForest.train(trainFeatures, trainLabels);
            svm.train(trainFeatures, trainLabels);
            multilayerPerceptron.train(trainFeatures, trainLabels);

            results.get("KNN").add(knn.evaluate(testFeatures, testLabels));
            results.get("DecisionTree").add(decisionTree.evaluate(testFeatures, testLabels));
            results.get("RandomForest").add(randomForest.evaluate(testFeatures, testLabels));
            results.get("SVM").add(svm.evaluate(testFeatures, testLabels));
            results.get("MLP").add(multilayerPerceptron.evaluate(testFeatures, testLabels));

            // Re-split the data and rebuild the MLP so every round runs on a fresh shuffle.
            initFeaturesAndLabels();
            updateMLP(testFeatures, testLabels);
        });

        printAverageAccuracies(results);
        printBestAndWorstAccuracies(results);
    }

    /** Prints all classifiers sorted by average accuracy, best first. */
    private void printAverageAccuracies(Map<String, List<Double>> results) {
        System.out.println("Classifier performances (sorted by average accuracy):");
        results.entrySet().stream()
                .map(entry -> {
                    double avgAccuracy = entry.getValue().stream()
                            .mapToDouble(Double::doubleValue)
                            .average()
                            .orElse(0.0);
                    return Map.entry(entry.getKey(), avgAccuracy);
                })
                .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
                .forEachOrdered(entry ->
                        // %n instead of \n: platform-appropriate line separator.
                        System.out.printf("%s: %.2f%%%n", entry.getKey(), entry.getValue() * 100));
    }

    /** Prints the best and worst observed accuracy per classifier. */
    private void printBestAndWorstAccuracies(Map<String, List<Double>> results) {
        System.out.println("\nClassifier best and worst performances:");
        results.forEach((classifierName, accuracies) -> {
            double maxAccuracy = accuracies.stream()
                    .mapToDouble(Double::doubleValue)
                    .max()
                    .orElse(0.0);
            double minAccuracy = accuracies.stream()
                    .mapToDouble(Double::doubleValue)
                    .min()
                    .orElse(0.0);
            System.out.printf("%s: Best: %.2f%%, Worst: %.2f%%%n",
                    classifierName, maxAccuracy * 100, minAccuracy * 100);
        });
    }

    /** Replaces the MLP with a fresh instance bound to the given evaluation data. */
    private void updateMLP(double[][] testFeatures, double[][] testLabels) {
        multilayerPerceptron = new MultilayerPerceptron(networkConfiguration, testFeatures, testLabels);
    }

    /** Loads the penguins CSV, splits it, and refreshes all feature/label arrays. */
    private void initFeaturesAndLabels() {
        var seabornDataProcessor = new SeabornDataProcessor();
        List<Penguin> data = seabornDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS);
        seabornDataProcessor.split(data, TRAIN_TEST_SPLIT_RATIO);

        trainFeatures = seabornDataProcessor.getTrainFeatures();
        trainLabels = seabornDataProcessor.getTrainLabels();

        testFeatures = seabornDataProcessor.getTestFeatures();
        testLabels = seabornDataProcessor.getTestLabels();
    }
}
4 changes: 3 additions & 1 deletion example/src/main/java/de/example/data/iris/Iris.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package de.example.data.iris;

public class Iris {

public class Iris{
public double sepalLength;
public double sepalWidth;
public double petalLength;
Expand Down Expand Up @@ -28,4 +29,5 @@ public String toString() {
public double[] getFeatures() {
return new double[]{sepalLength, sepalWidth, petalLength, petalWidth};
}

}
133 changes: 133 additions & 0 deletions example/src/main/java/de/example/data/iris/IrisDataProcessor.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package de.example.data.iris;

import de.edux.data.provider.DataProcessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;

/**
 * {@link DataProcessor} implementation for the classic Iris dataset:
 * four numeric features (sepal/petal length and width) and three one-hot
 * encoded classes (Setosa, Versicolor, Virginica).
 */
public class IrisDataProcessor extends DataProcessor<Iris> {
    private static final Logger LOG = LoggerFactory.getLogger(IrisDataProcessor.class);

    /**
     * Min-max normalizes all four features in place to [0, 1].
     * NOTE(review): throws NoSuchElementException for an empty dataset and yields
     * NaN when a feature is constant (max == min) — confirm callers never pass
     * such data.
     */
    @Override
    public void normalize(List<Iris> rowDataset) {
        double minSepalLength = rowDataset.stream().mapToDouble(iris -> iris.sepalLength).min().getAsDouble();
        double maxSepalLength = rowDataset.stream().mapToDouble(iris -> iris.sepalLength).max().getAsDouble();
        double minSepalWidth = rowDataset.stream().mapToDouble(iris -> iris.sepalWidth).min().getAsDouble();
        double maxSepalWidth = rowDataset.stream().mapToDouble(iris -> iris.sepalWidth).max().getAsDouble();
        double minPetalLength = rowDataset.stream().mapToDouble(iris -> iris.petalLength).min().getAsDouble();
        double maxPetalLength = rowDataset.stream().mapToDouble(iris -> iris.petalLength).max().getAsDouble();
        double minPetalWidth = rowDataset.stream().mapToDouble(iris -> iris.petalWidth).min().getAsDouble();
        double maxPetalWidth = rowDataset.stream().mapToDouble(iris -> iris.petalWidth).max().getAsDouble();

        for (Iris iris : rowDataset) {
            iris.sepalLength = (iris.sepalLength - minSepalLength) / (maxSepalLength - minSepalLength);
            iris.sepalWidth = (iris.sepalWidth - minSepalWidth) / (maxSepalWidth - minSepalWidth);
            iris.petalLength = (iris.petalLength - minPetalLength) / (maxPetalLength - minPetalLength);
            iris.petalWidth = (iris.petalWidth - minPetalWidth) / (maxPetalWidth - minPetalWidth);
        }
    }

    /** Parses one CSV line: four numeric feature columns followed by the variety. */
    @Override
    public Iris mapToDataRecord(String[] csvLine) {
        return new Iris(
                Double.parseDouble(csvLine[0]),
                Double.parseDouble(csvLine[1]),
                Double.parseDouble(csvLine[2]),
                Double.parseDouble(csvLine[3]),
                csvLine[4]
        );
    }

    /** Returns the raw feature matrix (rows = records, columns = 4 features). */
    @Override
    public double[][] getInputs(List<Iris> dataset) {
        // Same extraction as the train/test paths — delegate instead of duplicating.
        return featuresOf(dataset);
    }

    /** Returns one-hot encoded targets; an unknown variety yields an all-zero row. */
    @Override
    public double[][] getTargets(List<Iris> dataset) {
        // Previously this stored the result in an instance field as well; the
        // field was never read elsewhere, so the encoding is now stateless.
        return labelsOf(dataset);
    }

    @Override
    public double[][] getTrainFeatures() {
        return featuresOf(getSplitedDataset().trainData());
    }

    @Override
    public double[][] getTrainLabels() {
        return labelsOf(getSplitedDataset().trainData());
    }

    @Override
    public double[][] getTestLabels() {
        return labelsOf(getSplitedDataset().testData());
    }

    @Override
    public double[][] getTestFeatures() {
        return featuresOf(getSplitedDataset().testData());
    }

    /** Copies the four numeric features of each record into a matrix. */
    private double[][] featuresOf(List<Iris> data) {
        double[][] features = new double[data.size()][4];
        for (int i = 0; i < data.size(); i++) {
            features[i][0] = data.get(i).sepalLength;
            features[i][1] = data.get(i).sepalWidth;
            features[i][2] = data.get(i).petalLength;
            features[i][3] = data.get(i).petalWidth;
        }
        return features;
    }

    /**
     * One-hot encodes the variety of each record. Java zero-initializes array
     * rows, so only the hot index needs to be set; an unknown variety stays all
     * zeros, matching the original behavior of both encoding paths.
     */
    private double[][] labelsOf(List<Iris> data) {
        double[][] labels = new double[data.size()][3];
        for (int i = 0; i < data.size(); i++) {
            switch (data.get(i).variety) {
                case "Setosa":
                    labels[i][0] = 1;
                    break;
                case "Versicolor":
                    labels[i][1] = 1;
                    break;
                case "Virginica":
                    labels[i][2] = 1;
                    break;
            }
        }
        return labels;
    }

    @Override
    public String getDatasetDescription() {
        return "Iris dataset";
    }
}
10 changes: 5 additions & 5 deletions example/src/main/java/de/example/data/iris/IrisProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ public double[][] getTrainLabels() {
return labelsOf(trainingData);
}


@Override
public double[][] getTestLabels() {
return labelsOf(testData);
}

@Override
public double[][] getTestFeatures() {
Expand Down Expand Up @@ -100,10 +103,7 @@ private double[][] labelsOf(List<Iris> data) {
}
return labels;
}
@Override
public double[][] getTestLabels() {
return labelsOf(testData);
}


@Override
public String getDescription() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package de.example.data.seaborn;

import de.edux.data.provider.DataUtil;
import de.edux.data.provider.DataProcessor;

import java.util.ArrayList;
import java.util.List;

public class SeabornDataProcessor extends DataUtil<Penguin> {
public class SeabornDataProcessor extends DataProcessor<Penguin> {
@Override
public void normalize(List<Penguin> penguins) {
double maxBillLength = penguins.stream().mapToDouble(Penguin::billLengthMm).max().orElse(1);
Expand Down Expand Up @@ -118,4 +118,67 @@ public double[][] getTargets(List<Penguin> dataset) {

return targets;
}

@Override
public String getDatasetDescription() {
    // Human-readable name of the dataset this processor handles.
    String description = "Seaborn penguins dataset";
    return description;
}

@Override
public double[][] getTrainFeatures() {
    // Feature matrix of the training partition of the current split.
    var split = getSplitedDataset();
    return featuresOf(split.trainData());
}

@Override
public double[][] getTrainLabels() {
    // One-hot label matrix of the training partition of the current split.
    var split = getSplitedDataset();
    return labelsOf(split.trainData());
}

@Override
public double[][] getTestFeatures() {
    // Feature matrix of the test partition of the current split.
    var split = getSplitedDataset();
    return featuresOf(split.testData());
}

@Override
public double[][] getTestLabels() {
    // One-hot label matrix of the test partition of the current split.
    var split = getSplitedDataset();
    return labelsOf(split.testData());
}

// Builds the feature matrix: one row per penguin, four numeric columns
// (bill length, bill depth, flipper length, body mass).
private double[][] featuresOf(List<Penguin> data) {
    double[][] features = new double[data.size()][4];
    int row = 0;
    for (Penguin penguin : data) {
        features[row][0] = penguin.billLengthMm();
        features[row][1] = penguin.billDepthMm();
        features[row][2] = penguin.flipperLengthMm();
        features[row][3] = penguin.bodyMassG();
        row++;
    }
    return features;
}

// One-hot encodes the species of each penguin over the three known classes
// (adelie, chinstrap, gentoo); any other species is rejected.
// NOTE(review): toLowerCase() uses the default locale — confirm this is
// acceptable for the expected species strings.
private double[][] labelsOf(List<Penguin> data) {
    double[][] labels = new double[data.size()][3];
    int row = 0;
    for (Penguin penguin : data) {
        String species = penguin.species().toLowerCase();
        if (species.equals("adelie")) {
            labels[row] = new double[]{1.0, 0.0, 0.0};
        } else if (species.equals("chinstrap")) {
            labels[row] = new double[]{0.0, 1.0, 0.0};
        } else if (species.equals("gentoo")) {
            labels[row] = new double[]{0.0, 0.0, 1.0};
        } else {
            throw new IllegalArgumentException("Unbekannte Pinguinart: " + penguin.species());
        }
        row++;
    }
    return labels;
}

}
Loading

0 comments on commit f0c0ae8

Please sign in to comment.