Merge pull request #38 from Samyssmile/feature/unified-classifier
Feature/unified classifier

Showing 32 changed files with 1,175 additions and 641 deletions. The two new files reproduced below are the benchmark and the Iris data processor.
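The change puts KNN, decision tree, random forest, SVM, and the multilayer perceptron behind a single Classifier interface, so the benchmark below can train and evaluate all five through the same two calls. Judging only from those call sites, the contract in de.edux.api presumably looks like this (a sketch inferred from usage, not the source shipped in this PR):

package de.edux.api;

// Sketch reconstructed from the call sites in Benchmark.java below;
// return types that the benchmark ignores are assumptions.
public interface Classifier {

    // Trains the model on feature rows and one-hot encoded label rows.
    void train(double[][] features, double[][] labels);

    // Returns accuracy on the given test set as a fraction in [0, 1]
    // (the benchmark multiplies it by 100 for printing).
    double evaluate(double[][] testFeatures, double[][] testLabels);
}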
example/src/main/java/de/example/benchmark/Benchmark.java (new file: 143 additions, 0 deletions)
package de.example.benchmark;

import de.edux.api.Classifier;
import de.edux.functions.activation.ActivationFunction;
import de.edux.functions.initialization.Initialization;
import de.edux.functions.loss.LossFunction;
import de.edux.ml.decisiontree.DecisionTree;
import de.edux.ml.knn.KnnClassifier;
import de.edux.ml.nn.config.NetworkConfiguration;
import de.edux.ml.nn.network.MultilayerPerceptron;
import de.edux.ml.randomforest.RandomForest;
import de.edux.ml.svm.SVMKernel;
import de.edux.ml.svm.SupportVectorMachine;
import de.example.data.seaborn.Penguin;
import de.example.data.seaborn.SeabornDataProcessor;

import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.IntStream;

/**
 * Compares the performance of different classifiers on the Seaborn penguins dataset.
 */
public class Benchmark {
    private static final boolean SHUFFLE = true;
    private static final boolean NORMALIZE = true;
    private static final boolean FILTER_INCOMPLETE_RECORDS = true;
    private static final double TRAIN_TEST_SPLIT_RATIO = 0.75;
    private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "seaborn-penguins" + File.separator + "penguins.csv");
    private double[][] trainFeatures;
    private double[][] trainLabels;
    private double[][] testFeatures;
    private double[][] testLabels;
    private MultilayerPerceptron multilayerPerceptron;
    private NetworkConfiguration networkConfiguration;

    public static void main(String[] args) {
        new Benchmark().run();
    }

    private void run() {
        initFeaturesAndLabels();

        Classifier knn = new KnnClassifier(2);
        Classifier decisionTree = new DecisionTree(8, 2, 1, 3);
        Classifier randomForest = new RandomForest(100, 10, 2, 1, 3, 60);
        Classifier svm = new SupportVectorMachine(SVMKernel.LINEAR, 1);

        // Parameter roles inferred from the names: input size, hidden layers,
        // output classes, learning rate, epochs, activations, loss, initializations.
        networkConfiguration = new NetworkConfiguration(trainFeatures[0].length, List.of(128, 256, 512), 3,
                0.01, 300, ActivationFunction.LEAKY_RELU, ActivationFunction.SOFTMAX,
                LossFunction.CATEGORICAL_CROSS_ENTROPY, Initialization.XAVIER, Initialization.XAVIER);
        multilayerPerceptron = new MultilayerPerceptron(networkConfiguration, testFeatures, testLabels);
        Map<String, Classifier> classifiers = Map.of(
                "KNN", knn,
                "DecisionTree", decisionTree,
                "RandomForest", randomForest,
                "SVM", svm,
                "MLP", multilayerPerceptron
        );

        Map<String, List<Double>> results = new ConcurrentHashMap<>();
        results.put("KNN", new ArrayList<>());
        results.put("DecisionTree", new ArrayList<>());
        results.put("RandomForest", new ArrayList<>());
        results.put("SVM", new ArrayList<>());
        results.put("MLP", new ArrayList<>());

        IntStream.range(0, 50).forEach(i -> {
            knn.train(trainFeatures, trainLabels);
            decisionTree.train(trainFeatures, trainLabels);
            randomForest.train(trainFeatures, trainLabels);
            svm.train(trainFeatures, trainLabels);
            multilayerPerceptron.train(trainFeatures, trainLabels);

            double knnAccuracy = knn.evaluate(testFeatures, testLabels);
            double decisionTreeAccuracy = decisionTree.evaluate(testFeatures, testLabels);
            double randomForestAccuracy = randomForest.evaluate(testFeatures, testLabels);
            double svmAccuracy = svm.evaluate(testFeatures, testLabels);
            double multilayerPerceptronAccuracy = multilayerPerceptron.evaluate(testFeatures, testLabels);

            results.get("KNN").add(knnAccuracy);
            results.get("DecisionTree").add(decisionTreeAccuracy);
            results.get("RandomForest").add(randomForestAccuracy);
            results.get("SVM").add(svmAccuracy);
            results.get("MLP").add(multilayerPerceptronAccuracy);

            // Draw a fresh train/test split and rebuild the MLP for the next round.
            initFeaturesAndLabels();
            updateMLP(testFeatures, testLabels);
        });

        // Sort the classifiers by average accuracy, best first, and print the ranking.
        System.out.println("Classifier performances (sorted by average accuracy):");
        results.entrySet().stream()
                .map(entry -> {
                    double avgAccuracy = entry.getValue().stream()
                            .mapToDouble(Double::doubleValue)
                            .average()
                            .orElse(0.0);
                    return Map.entry(entry.getKey(), avgAccuracy);
                })
                .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
                .forEachOrdered(entry -> {
                    System.out.printf("%s: %.2f%%%n", entry.getKey(), entry.getValue() * 100);
                });

        // Other metrics, such as the minimum and maximum accuracy, can be computed the same way.
        System.out.println("\nClassifier best and worst performances:");
        results.forEach((classifierName, accuracies) -> {
            double maxAccuracy = accuracies.stream()
                    .mapToDouble(Double::doubleValue)
                    .max()
                    .orElse(0.0);
            double minAccuracy = accuracies.stream()
                    .mapToDouble(Double::doubleValue)
                    .min()
                    .orElse(0.0);
            System.out.printf("%s: Best: %.2f%%, Worst: %.2f%%%n", classifierName, maxAccuracy * 100, minAccuracy * 100);
        });
    }

    private void updateMLP(double[][] testFeatures, double[][] testLabels) {
        multilayerPerceptron = new MultilayerPerceptron(networkConfiguration, testFeatures, testLabels);
    }

    private void initFeaturesAndLabels() {
        var seabornDataProcessor = new SeabornDataProcessor();
        List<Penguin> data = seabornDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS);
        seabornDataProcessor.split(data, TRAIN_TEST_SPLIT_RATIO);

        trainFeatures = seabornDataProcessor.getTrainFeatures();
        trainLabels = seabornDataProcessor.getTrainLabels();

        testFeatures = seabornDataProcessor.getTestFeatures();
        testLabels = seabornDataProcessor.getTestLabels();
    }
}
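Because every model now satisfies the same contract, the five per-model train/evaluate pairs above could be driven generically. A minimal sketch of that pattern follows; the ClassifierRunner class and accuracyOf helper are hypothetical, everything else is taken from the benchmark:

package de.example.benchmark;

import de.edux.api.Classifier;

// Hypothetical helper, not part of this PR: with the unified interface,
// one method can train and score any of the five classifiers.
public final class ClassifierRunner {

    private ClassifierRunner() {
    }

    public static double accuracyOf(Classifier classifier,
                                    double[][] trainFeatures, double[][] trainLabels,
                                    double[][] testFeatures, double[][] testLabels) {
        classifier.train(trainFeatures, trainLabels);
        return classifier.evaluate(testFeatures, testLabels);
    }
}

Iterating the classifiers map through such a helper would collapse the repeated blocks inside IntStream.range(0, 50) into a single loop.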
example/src/main/java/de/example/data/iris/IrisDataProcessor.java (new file: 133 additions, 0 deletions)
package de.example.data.iris;

import de.edux.data.provider.DataProcessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;

public class IrisDataProcessor extends DataProcessor<Iris> {
    private static final Logger LOG = LoggerFactory.getLogger(IrisDataProcessor.class);
    private double[][] targets;

    @Override
    public void normalize(List<Iris> rawDataset) {
        // Min-max scale each of the four features into [0, 1].
        double minSepalLength = rawDataset.stream().mapToDouble(iris -> iris.sepalLength).min().getAsDouble();
        double maxSepalLength = rawDataset.stream().mapToDouble(iris -> iris.sepalLength).max().getAsDouble();
        double minSepalWidth = rawDataset.stream().mapToDouble(iris -> iris.sepalWidth).min().getAsDouble();
        double maxSepalWidth = rawDataset.stream().mapToDouble(iris -> iris.sepalWidth).max().getAsDouble();
        double minPetalLength = rawDataset.stream().mapToDouble(iris -> iris.petalLength).min().getAsDouble();
        double maxPetalLength = rawDataset.stream().mapToDouble(iris -> iris.petalLength).max().getAsDouble();
        double minPetalWidth = rawDataset.stream().mapToDouble(iris -> iris.petalWidth).min().getAsDouble();
        double maxPetalWidth = rawDataset.stream().mapToDouble(iris -> iris.petalWidth).max().getAsDouble();

        for (Iris iris : rawDataset) {
            iris.sepalLength = (iris.sepalLength - minSepalLength) / (maxSepalLength - minSepalLength);
            iris.sepalWidth = (iris.sepalWidth - minSepalWidth) / (maxSepalWidth - minSepalWidth);
            iris.petalLength = (iris.petalLength - minPetalLength) / (maxPetalLength - minPetalLength);
            iris.petalWidth = (iris.petalWidth - minPetalWidth) / (maxPetalWidth - minPetalWidth);
        }
    }

    @Override
    public Iris mapToDataRecord(String[] csvLine) {
        return new Iris(
                Double.parseDouble(csvLine[0]),
                Double.parseDouble(csvLine[1]),
                Double.parseDouble(csvLine[2]),
                Double.parseDouble(csvLine[3]),
                csvLine[4]
        );
    }

    @Override
    public double[][] getInputs(List<Iris> dataset) {
        double[][] inputs = new double[dataset.size()][4];
        for (int i = 0; i < dataset.size(); i++) {
            inputs[i][0] = dataset.get(i).sepalLength;
            inputs[i][1] = dataset.get(i).sepalWidth;
            inputs[i][2] = dataset.get(i).petalLength;
            inputs[i][3] = dataset.get(i).petalWidth;
        }
        return inputs;
    }

    @Override
    public double[][] getTargets(List<Iris> dataset) {
        // One-hot encode the three varieties; unmatched rows stay all zeros.
        targets = new double[dataset.size()][3];
        for (int i = 0; i < dataset.size(); i++) {
            switch (dataset.get(i).variety) {
                case "Setosa":
                    targets[i][0] = 1;
                    break;
                case "Versicolor":
                    targets[i][1] = 1;
                    break;
                case "Virginica":
                    targets[i][2] = 1;
                    break;
            }
        }
        return targets;
    }

    @Override
    public double[][] getTrainFeatures() {
        return featuresOf(getSplitedDataset().trainData());
    }

    @Override
    public double[][] getTrainLabels() {
        return labelsOf(getSplitedDataset().trainData());
    }

    @Override
    public double[][] getTestLabels() {
        return labelsOf(getSplitedDataset().testData());
    }

    @Override
    public double[][] getTestFeatures() {
        return featuresOf(getSplitedDataset().testData());
    }

    private double[][] featuresOf(List<Iris> data) {
        double[][] features = new double[data.size()][4];
        for (int i = 0; i < data.size(); i++) {
            features[i][0] = data.get(i).sepalLength;
            features[i][1] = data.get(i).sepalWidth;
            features[i][2] = data.get(i).petalLength;
            features[i][3] = data.get(i).petalWidth;
        }
        return features;
    }

    private double[][] labelsOf(List<Iris> data) {
        // One-hot encoding, as in getTargets; Java zero-initializes the rows,
        // so only the matching class index needs to be set.
        double[][] labels = new double[data.size()][3];
        for (int i = 0; i < data.size(); i++) {
            if (data.get(i).variety.equals("Setosa")) {
                labels[i][0] = 1;
            }
            if (data.get(i).variety.equals("Versicolor")) {
                labels[i][1] = 1;
            }
            if (data.get(i).variety.equals("Virginica")) {
                labels[i][2] = 1;
            }
        }
        return labels;
    }

    @Override
    public String getDatasetDescription() {
        return "Iris dataset";
    }
}
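For context, this is how the new processor would plug into the unified API, mirroring how Benchmark.java drives SeabornDataProcessor. The IrisExample class, the iris.csv path, and the choice of KnnClassifier are illustrative assumptions; the method signatures come from this diff:

package de.example.data.iris;

import de.edux.api.Classifier;
import de.edux.ml.knn.KnnClassifier;

import java.io.File;
import java.util.List;

public class IrisExample {
    public static void main(String[] args) {
        var processor = new IrisDataProcessor();
        // The CSV location is assumed for illustration; this diff does not show an iris dataset file.
        File csv = new File("example" + File.separator + "datasets" + File.separator + "iris" + File.separator + "iris.csv");
        List<Iris> data = processor.loadDataSetFromCSV(csv, ',', true, true, true);
        processor.split(data, 0.75);

        Classifier knn = new KnnClassifier(2);
        knn.train(processor.getTrainFeatures(), processor.getTrainLabels());
        double accuracy = knn.evaluate(processor.getTestFeatures(), processor.getTestLabels());
        System.out.printf("Iris accuracy: %.2f%%%n", accuracy * 100);
    }
}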