diff --git a/example/src/main/java/de/example/benchmark/Benchmark.java b/example/src/main/java/de/example/benchmark/Benchmark.java new file mode 100644 index 0000000..0cfe866 --- /dev/null +++ b/example/src/main/java/de/example/benchmark/Benchmark.java @@ -0,0 +1,143 @@ +package de.example.benchmark; + +import de.edux.api.Classifier; +import de.edux.functions.activation.ActivationFunction; +import de.edux.functions.initialization.Initialization; +import de.edux.functions.loss.LossFunction; +import de.edux.ml.decisiontree.DecisionTree; +import de.edux.ml.knn.KnnClassifier; +import de.edux.ml.nn.config.NetworkConfiguration; +import de.edux.ml.nn.network.MultilayerPerceptron; +import de.edux.ml.randomforest.RandomForest; +import de.edux.ml.svm.SVMKernel; +import de.edux.ml.svm.SupportVectorMachine; +import de.example.data.seaborn.Penguin; +import de.example.data.seaborn.SeabornDataProcessor; + +import java.util.ArrayList; +import java.util.Map; +import java.io.File; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.DoubleStream; +import java.util.stream.IntStream; + +/** + * Compare the performance of different classifiers + */ +public class Benchmark { + private static final boolean SHUFFLE = true; + private static final boolean NORMALIZE = true; + private static final boolean FILTER_INCOMPLETE_RECORDS = true; + private static final double TRAIN_TEST_SPLIT_RATIO = 0.75; + private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "seaborn-penguins" + File.separator + "penguins.csv"); + private double[][] trainFeatures; + private double[][] trainLabels; + private double[][] testFeatures; + private double[][] testLabels; + private MultilayerPerceptron multilayerPerceptron; + private NetworkConfiguration networkConfiguration; + + public static void main(String[] args) { + new Benchmark().run(); + } + + private void run() { + initFeaturesAndLabels(); + + Classifier knn = new 
KnnClassifier(2); + Classifier decisionTree = new DecisionTree(8, 2, 1, 3); + Classifier randomForest = new RandomForest(100, 10, 2, 1, 3, 60); + Classifier svm = new SupportVectorMachine(SVMKernel.LINEAR, 1); + + networkConfiguration = new NetworkConfiguration(trainFeatures[0].length, List.of(128,256, 512), 3, 0.01, 300, ActivationFunction.LEAKY_RELU, ActivationFunction.SOFTMAX, LossFunction.CATEGORICAL_CROSS_ENTROPY, Initialization.XAVIER, Initialization.XAVIER); + multilayerPerceptron = new MultilayerPerceptron(networkConfiguration, testFeatures, testLabels); + Map classifiers = Map.of( + "KNN", knn, + "DecisionTree", decisionTree, + "RandomForest", randomForest, + "SVM", svm, + "MLP", multilayerPerceptron + ); + + Map> results = new ConcurrentHashMap<>(); + results.put("KNN", new ArrayList<>()); + results.put("DecisionTree", new ArrayList<>()); + results.put("RandomForest", new ArrayList<>()); + results.put("SVM", new ArrayList<>()); + results.put("MLP", new ArrayList<>()); + + + IntStream.range(0, 50).forEach(i -> { + knn.train(trainFeatures, trainLabels); + decisionTree.train(trainFeatures, trainLabels); + randomForest.train(trainFeatures, trainLabels); + svm.train(trainFeatures, trainLabels); + multilayerPerceptron.train(trainFeatures, trainLabels); + + double knnAccuracy = knn.evaluate(testFeatures, testLabels); + double decisionTreeAccuracy = decisionTree.evaluate(testFeatures, testLabels); + double randomForestAccuracy = randomForest.evaluate(testFeatures, testLabels); + double svmAccuracy = svm.evaluate(testFeatures, testLabels); + double multilayerPerceptronAccuracy = multilayerPerceptron.evaluate(testFeatures, testLabels); + + results.get("KNN").add(knnAccuracy); + results.get("DecisionTree").add(decisionTreeAccuracy); + results.get("RandomForest").add(randomForestAccuracy); + results.get("SVM").add(svmAccuracy); + results.get("MLP").add(multilayerPerceptronAccuracy); + initFeaturesAndLabels(); + updateMLP(testFeatures, testLabels); + }); + + + //Sort 
and print results with numeration begin with best average accuracy + System.out.println("Classifier performances (sorted by average accuracy):"); + results.entrySet().stream() + .map(entry -> { + double avgAccuracy = entry.getValue().stream() + .mapToDouble(Double::doubleValue) + .average() + .orElse(0.0); + return Map.entry(entry.getKey(), avgAccuracy); + }) + .sorted(Map.Entry.comparingByValue().reversed()) + .forEachOrdered(entry -> { + System.out.printf("%s: %.2f%%\n", entry.getKey(), entry.getValue() * 100); + }); + + // Additionally, if you want to show other metrics, such as minimum or maximum accuracy, you can calculate and display them similarly. + System.out.println("\nClassifier best and worst performances:"); + results.forEach((classifierName, accuracies) -> { + double maxAccuracy = accuracies.stream() + .mapToDouble(Double::doubleValue) + .max() + .orElse(0.0); + double minAccuracy = accuracies.stream() + .mapToDouble(Double::doubleValue) + .min() + .orElse(0.0); + System.out.printf("%s: Best: %.2f%%, Worst: %.2f%%\n", classifierName, maxAccuracy * 100, minAccuracy * 100); + }); + + + } + + private void updateMLP(double[][] testFeatures, double[][] testLabels) { + multilayerPerceptron = new MultilayerPerceptron(networkConfiguration, testFeatures, testLabels); + } + + private void initFeaturesAndLabels() { + var seabornDataProcessor = new SeabornDataProcessor(); + List data = seabornDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', NORMALIZE, SHUFFLE, FILTER_INCOMPLETE_RECORDS); // parameter order is (normalize, shuffle) + seabornDataProcessor.split(data, TRAIN_TEST_SPLIT_RATIO); + + trainFeatures = seabornDataProcessor.getTrainFeatures(); + trainLabels = seabornDataProcessor.getTrainLabels(); + + testFeatures = seabornDataProcessor.getTestFeatures(); + testLabels = seabornDataProcessor.getTestLabels(); + + + } +} diff --git a/example/src/main/java/de/example/data/iris/Iris.java b/example/src/main/java/de/example/data/iris/Iris.java index 6274c6f..a56feb1 100644 --- 
a/example/src/main/java/de/example/data/iris/Iris.java +++ b/example/src/main/java/de/example/data/iris/Iris.java @@ -1,6 +1,7 @@ package de.example.data.iris; -public class Iris { + +public class Iris{ public double sepalLength; public double sepalWidth; public double petalLength; @@ -28,4 +29,5 @@ public String toString() { public double[] getFeatures() { return new double[]{sepalLength, sepalWidth, petalLength, petalWidth}; } + } diff --git a/example/src/main/java/de/example/data/iris/IrisDataProcessor.java b/example/src/main/java/de/example/data/iris/IrisDataProcessor.java new file mode 100644 index 0000000..bbe9c12 --- /dev/null +++ b/example/src/main/java/de/example/data/iris/IrisDataProcessor.java @@ -0,0 +1,133 @@ +package de.example.data.iris; + +import de.edux.data.provider.DataProcessor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; + +public class IrisDataProcessor extends DataProcessor { + private static final Logger LOG = LoggerFactory.getLogger(IrisDataProcessor.class); + private double[][] targets; + + @Override + public void normalize(List rowDataset) { + double minSepalLength = rowDataset.stream().mapToDouble(iris -> iris.sepalLength).min().orElse(0.0); // orElse: an empty dataset must not throw + double maxSepalLength = rowDataset.stream().mapToDouble(iris -> iris.sepalLength).max().orElse(0.0); + double minSepalWidth = rowDataset.stream().mapToDouble(iris -> iris.sepalWidth).min().orElse(0.0); + double maxSepalWidth = rowDataset.stream().mapToDouble(iris -> iris.sepalWidth).max().orElse(0.0); + double minPetalLength = rowDataset.stream().mapToDouble(iris -> iris.petalLength).min().orElse(0.0); + double maxPetalLength = rowDataset.stream().mapToDouble(iris -> iris.petalLength).max().orElse(0.0); + double minPetalWidth = rowDataset.stream().mapToDouble(iris -> iris.petalWidth).min().orElse(0.0); + double maxPetalWidth = rowDataset.stream().mapToDouble(iris -> iris.petalWidth).max().orElse(0.0); + + for (Iris iris : rowDataset) { + 
iris.sepalLength = maxSepalLength == minSepalLength ? 0.0 : (iris.sepalLength - minSepalLength) / (maxSepalLength - minSepalLength); // a constant column maps to 0.0 instead of 0/0 = NaN + iris.sepalWidth = maxSepalWidth == minSepalWidth ? 0.0 : (iris.sepalWidth - minSepalWidth) / (maxSepalWidth - minSepalWidth); + iris.petalLength = maxPetalLength == minPetalLength ? 0.0 : (iris.petalLength - minPetalLength) / (maxPetalLength - minPetalLength); + iris.petalWidth = maxPetalWidth == minPetalWidth ? 0.0 : (iris.petalWidth - minPetalWidth) / (maxPetalWidth - minPetalWidth); + } + } + + @Override + public Iris mapToDataRecord(String[] csvLine) { + return new Iris( + Double.parseDouble(csvLine[0]), + Double.parseDouble(csvLine[1]), + Double.parseDouble(csvLine[2]), + Double.parseDouble(csvLine[3]), + csvLine[4] + ); + } + + @Override + public double[][] getInputs(List dataset) { + double[][] inputs = new double[dataset.size()][4]; + for (int i = 0; i < dataset.size(); i++) { + inputs[i][0] = dataset.get(i).sepalLength; + inputs[i][1] = dataset.get(i).sepalWidth; + inputs[i][2] = dataset.get(i).petalLength; + inputs[i][3] = dataset.get(i).petalWidth; + } + return inputs; + } + + @Override + public double[][] getTargets(List dataset) { + targets = new double[dataset.size()][3]; + for (int i = 0; i < dataset.size(); i++) { + switch (dataset.get(i).variety) { + case "Setosa": + targets[i][0] = 1; + break; + case "Versicolor": + targets[i][1] = 1; + break; + case "Virginica": + targets[i][2] = 1; + break; + } + } + return targets; + } + + @Override + public double[][] getTrainFeatures() { + return featuresOf(getSplitedDataset().trainData()); + } + + @Override + public double[][] getTrainLabels() { + return labelsOf(getSplitedDataset().trainData()); + } + + @Override + public double[][] getTestLabels() { + return labelsOf(getSplitedDataset().testData()); + } + + @Override + public double[][] getTestFeatures() { + return featuresOf(getSplitedDataset().testData()); + } + + private double[][] featuresOf(List testData) { + double[][] features = new double[testData.size()][4]; + for (int i = 0; i < testData.size(); i++) { + features[i][0] = testData.get(i).sepalLength; + features[i][1] = 
testData.get(i).sepalWidth; + features[i][2] = testData.get(i).petalLength; + features[i][3] = testData.get(i).petalWidth; + } + return features; + } + + private double[][] labelsOf(List data) { + double[][] labels = new double[data.size()][3]; + for (int i = 0; i < data.size(); i++) { + if (data.get(i).variety.equals("Setosa")) { + labels[i][0] = 1; + labels[i][1] = 0; + labels[i][2] = 0; + } + if (data.get(i).variety.equals("Versicolor")) { + labels[i][0] = 0; + labels[i][1] = 1; + labels[i][2] = 0; + } + if (data.get(i).variety.equals("Virginica")) { + labels[i][0] = 0; + labels[i][1] = 0; + labels[i][2] = 1; + } + } + return labels; + } + + @Override + public String getDatasetDescription() { + return "Iris dataset"; + } + + +} diff --git a/example/src/main/java/de/example/data/iris/IrisProvider.java b/example/src/main/java/de/example/data/iris/IrisProvider.java index f700be1..cf1b0d0 100644 --- a/example/src/main/java/de/example/data/iris/IrisProvider.java +++ b/example/src/main/java/de/example/data/iris/IrisProvider.java @@ -62,7 +62,10 @@ public double[][] getTrainLabels() { return labelsOf(trainingData); } - + @Override + public double[][] getTestLabels() { + return labelsOf(testData); + } @Override public double[][] getTestFeatures() { @@ -100,10 +103,7 @@ private double[][] labelsOf(List data) { } return labels; } - @Override - public double[][] getTestLabels() { - return labelsOf(testData); - } + @Override public String getDescription() { diff --git a/example/src/main/java/de/example/data/seaborn/SeabornDataProcessor.java b/example/src/main/java/de/example/data/seaborn/SeabornDataProcessor.java index 323c007..124b6f3 100644 --- a/example/src/main/java/de/example/data/seaborn/SeabornDataProcessor.java +++ b/example/src/main/java/de/example/data/seaborn/SeabornDataProcessor.java @@ -1,11 +1,11 @@ package de.example.data.seaborn; -import de.edux.data.provider.DataUtil; +import de.edux.data.provider.DataProcessor; import java.util.ArrayList; import 
java.util.List; -public class SeabornDataProcessor extends DataUtil { +public class SeabornDataProcessor extends DataProcessor { @Override public void normalize(List penguins) { double maxBillLength = penguins.stream().mapToDouble(Penguin::billLengthMm).max().orElse(1); @@ -118,4 +118,67 @@ public double[][] getTargets(List dataset) { return targets; } + + @Override + public String getDatasetDescription() { + return "Seaborn penguins dataset"; + } + + @Override + public double[][] getTrainFeatures() { + return featuresOf(getSplitedDataset().trainData()); + } + + @Override + public double[][] getTrainLabels() { + return labelsOf(getSplitedDataset().trainData()); + } + + @Override + public double[][] getTestFeatures() { + return featuresOf(getSplitedDataset().testData()); + } + + @Override + public double[][] getTestLabels() { + return labelsOf(getSplitedDataset().testData()); + } + + private double[][] featuresOf(List data) { + double[][] features = new double[data.size()][4]; // 4 numerische Eigenschaften + + for (int i = 0; i < data.size(); i++) { + Penguin p = data.get(i); + features[i][0] = p.billLengthMm(); + features[i][1] = p.billDepthMm(); + features[i][2] = p.flipperLengthMm(); + features[i][3] = p.bodyMassG(); + } + + return features; + } + + private double[][] labelsOf(List data) { + double[][] labels = new double[data.size()][3]; // 3 Pinguinarten + + for (int i = 0; i < data.size(); i++) { + Penguin p = data.get(i); + switch (p.species().toLowerCase()) { + case "adelie": + labels[i] = new double[]{1.0, 0.0, 0.0}; + break; + case "chinstrap": + labels[i] = new double[]{0.0, 1.0, 0.0}; + break; + case "gentoo": + labels[i] = new double[]{0.0, 0.0, 1.0}; + break; + default: + throw new IllegalArgumentException("Unbekannte Pinguinart: " + p.species()); + } + } + + return labels; + } + } \ No newline at end of file diff --git a/example/src/main/java/de/example/data/seaborn/SeabornProvider.java 
b/example/src/main/java/de/example/data/seaborn/SeabornProvider.java index 8f354bb..3c370d6 100644 --- a/example/src/main/java/de/example/data/seaborn/SeabornProvider.java +++ b/example/src/main/java/de/example/data/seaborn/SeabornProvider.java @@ -49,6 +49,19 @@ public Penguin getRandom(boolean equalDistribution) { public double[][] getTrainFeatures() { return featuresOf(trainingData); } + @Override + public double[][] getTrainLabels() { + return labelsOf(trainingData); + } + @Override + public double[][] getTestFeatures() { + return featuresOf(testData); + } + + @Override + public double[][] getTestLabels() { + return labelsOf(testData); + } private double[][] featuresOf(List data) { double[][] features = new double[data.size()][4]; // 4 numerische Eigenschaften @@ -63,13 +76,6 @@ private double[][] featuresOf(List data) { return features; } - - - @Override - public double[][] getTrainLabels() { - return labelsOf(trainingData); - } - private double[][] labelsOf(List data) { double[][] labels = new double[data.size()][3]; // 3 Pinguinarten @@ -93,15 +99,7 @@ private double[][] labelsOf(List data) { return labels; } - @Override - public double[][] getTestFeatures() { - return featuresOf(testData); - } - @Override - public double[][] getTestLabels() { - return labelsOf(testData); - } @Override public String getDescription() { diff --git a/example/src/main/java/de/example/decisiontree/DecisionTreeExample.java b/example/src/main/java/de/example/decisiontree/DecisionTreeExample.java index 9b03f7a..3c856cd 100644 --- a/example/src/main/java/de/example/decisiontree/DecisionTreeExample.java +++ b/example/src/main/java/de/example/decisiontree/DecisionTreeExample.java @@ -1,11 +1,9 @@ package de.example.decisiontree; import de.edux.ml.decisiontree.DecisionTree; -import de.edux.ml.decisiontree.IDecisionTree; import de.example.data.iris.IrisProvider; -import java.util.Arrays; -import static de.edux.util.LabelDimensionConverter.convert2DLabelArrayTo1DLabelArray; +import 
java.util.Arrays; public class DecisionTreeExample { private static final boolean SHUFFLE = true; @@ -21,17 +19,13 @@ public static void main(String[] args) { double[][] labels = datasetProvider.getTrainLabels(); // Train Decision Tree - IDecisionTree decisionTree = new DecisionTree(); - decisionTree.train(features, labels, 6, 2, 1, 4); + DecisionTree decisionTree = new DecisionTree(8, 2, 1, 4); + decisionTree.train(features, labels); // Evaluate Decision Tree double[][] testFeatures = datasetProvider.getTestFeatures(); double[][] testLabels = datasetProvider.getTestLabels(); decisionTree.evaluate(testFeatures, testLabels); - // Get Feature Importance - double[] featureImportance = decisionTree.getFeatureImportance(); - System.out.println("Feature Importance: " + Arrays.toString(featureImportance)); } - } diff --git a/example/src/main/java/de/example/knn/KnnIrisExample.java b/example/src/main/java/de/example/knn/KnnIrisExample.java index 999b526..0e5d536 100644 --- a/example/src/main/java/de/example/knn/KnnIrisExample.java +++ b/example/src/main/java/de/example/knn/KnnIrisExample.java @@ -1,12 +1,12 @@ package de.example.knn; -import de.edux.ml.knn.ILabeledPoint; +import de.edux.api.Classifier; import de.edux.ml.knn.KnnClassifier; -import de.edux.ml.knn.KnnPoint; +import de.edux.ml.nn.network.api.Dataset; import de.example.data.iris.Iris; -import de.example.data.iris.IrisProvider; +import de.example.data.iris.IrisDataProcessor; -import java.util.ArrayList; +import java.io.File; import java.util.List; /** @@ -17,28 +17,18 @@ public class KnnIrisExample { private static final boolean SHUFFLE = true; private static final boolean NORMALIZE = true; + private static final boolean FILTER_INCOMPLETE_RECORDS = true; + private static final double TRAIN_TEST_SPLIT_RATIO = 0.75; + private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "iris" + File.separator + "iris.csv"); public static void main(String[] args) { - var 
datasetProvider = new IrisProvider(NORMALIZE, SHUFFLE, 0.6); - datasetProvider.printStatistics(); - - List labeledPoints = new ArrayList<>(); - for (int i = 0; i < datasetProvider.getTrainFeatures().length; i++) { - labeledPoints.add(new KnnPoint(datasetProvider.getTrainFeatures()[i], datasetProvider.getTrainData().get(i).variety)); - } - - KnnClassifier knnClassifier = new KnnClassifier(1, labeledPoints); - - // Evaluate on test data - // transfer Iris to KnnPoint - List testDataset = datasetProvider.getTestData(); - List testLabeledPoints = new ArrayList<>(); - testDataset.forEach(iris -> { - ILabeledPoint labeledPoint = new KnnPoint(iris.getFeatures(), iris.variety); - testLabeledPoints.add(labeledPoint); - }); - - //Evaluate - knnClassifier.evaluate(testLabeledPoints); + var irisDataProcessor = new IrisDataProcessor(); + List data = irisDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', NORMALIZE, SHUFFLE, FILTER_INCOMPLETE_RECORDS); // parameter order is (normalize, shuffle) + irisDataProcessor.split(data, TRAIN_TEST_SPLIT_RATIO); + + Classifier knn = new KnnClassifier(2); + //Train and evaluate + knn.train(irisDataProcessor.getTrainFeatures(), irisDataProcessor.getTrainLabels()); + knn.evaluate(irisDataProcessor.getTestFeatures(), irisDataProcessor.getTestLabels()); } } \ No newline at end of file diff --git a/example/src/main/java/de/example/knn/KnnSeabornExample.java b/example/src/main/java/de/example/knn/KnnSeabornExample.java index 30476fb..faf4f90 100644 --- a/example/src/main/java/de/example/knn/KnnSeabornExample.java +++ b/example/src/main/java/de/example/knn/KnnSeabornExample.java @@ -1,14 +1,13 @@ package de.example.knn; -import de.edux.ml.knn.ILabeledPoint; +import de.edux.api.Classifier; import de.edux.ml.knn.KnnClassifier; -import de.edux.ml.knn.KnnPoint; +import de.edux.ml.nn.network.api.Dataset; import de.example.data.seaborn.Penguin; import de.example.data.seaborn.SeabornDataProcessor; import de.example.data.seaborn.SeabornProvider; import java.io.File; -import java.util.ArrayList; import 
java.util.List; /** @@ -19,34 +18,21 @@ public class KnnSeabornExample { private static final boolean SHUFFLE = true; private static final boolean NORMALIZE = true; private static final boolean FILTER_INCOMPLETE_RECORDS = true; - private static final double TRAIN_TEST_SPLIT_RATIO = 0.7; + private static final double TRAIN_TEST_SPLIT_RATIO = 0.75; private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "seaborn-penguins" + File.separator + "penguins.csv"); public static void main(String[] args) { - //Load dataset + //Load Data, shuffle, normalize, filter incomplete records out. var seabornDataProcessor = new SeabornDataProcessor(); - List dataset = seabornDataProcessor.loadTDataSet(CSV_FILE, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS); - List> trainTestSplittedList = seabornDataProcessor.split(dataset, TRAIN_TEST_SPLIT_RATIO); - var seabornProvider = new SeabornProvider(dataset, trainTestSplittedList.get(0), trainTestSplittedList.get(1)); + List data = seabornDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', NORMALIZE, SHUFFLE, FILTER_INCOMPLETE_RECORDS); // parameter order is (normalize, shuffle) + //Split dataset into train and test + Dataset dataset = seabornDataProcessor.split(data, TRAIN_TEST_SPLIT_RATIO); + var seabornProvider = new SeabornProvider(data, dataset.trainData(), dataset.testData()); seabornProvider.printStatistics(); - - // Train classifier - List labeledPoints = new ArrayList<>(); - for (int i = 0; i < seabornProvider.getTrainFeatures().length; i++) { - labeledPoints.add(new KnnPoint(seabornProvider.getTrainFeatures()[i], seabornProvider.getTrainData().get(i).species())); - } - - KnnClassifier knnClassifier = new KnnClassifier(1, labeledPoints); - - // Evaluate classifier - List testDataset = seabornProvider.getTestData(); - List testLabeledPoints = new ArrayList<>(); - testDataset.forEach(penguin -> { - ILabeledPoint labeledPoint = new KnnPoint(penguin.getFeatures(), penguin.species()); - testLabeledPoints.add(labeledPoint); - }); 
- knnClassifier.evaluate(testLabeledPoints); + Classifier knn = new KnnClassifier(2); + //Train and evaluate + knn.train(seabornProvider.getTrainFeatures(), seabornProvider.getTrainLabels()); + knn.evaluate(seabornProvider.getTestFeatures(), seabornProvider.getTestLabels()); } } diff --git a/example/src/main/java/de/example/nn/MultilayerPerceptronExample.java b/example/src/main/java/de/example/nn/MultilayerPerceptronExample.java index a064c01..4e90069 100644 --- a/example/src/main/java/de/example/nn/MultilayerPerceptronExample.java +++ b/example/src/main/java/de/example/nn/MultilayerPerceptronExample.java @@ -38,10 +38,11 @@ public static void main(String[] args) { // - Categorical Cross Entropy as Loss Function // - Xavier as Weight Initialization for Hidden Layers // - Xavier as Weight Initialization for Output Layer - NetworkConfiguration networkConfiguration = new NetworkConfiguration(features[0].length, List.of(12, 6), 3, 0.01, 1000, ActivationFunction.LEAKY_RELU, ActivationFunction.SOFTMAX, LossFunction.CATEGORICAL_CROSS_ENTROPY, Initialization.XAVIER, Initialization.XAVIER); + NetworkConfiguration networkConfiguration = new NetworkConfiguration(features[0].length, List.of(32, 6), 3, 0.01, 1000, ActivationFunction.LEAKY_RELU, ActivationFunction.SOFTMAX, LossFunction.CATEGORICAL_CROSS_ENTROPY, Initialization.XAVIER, Initialization.XAVIER); - MultilayerPerceptron multilayerPerceptron = new MultilayerPerceptron(features, labels, testFeatures, testLabels, networkConfiguration); - multilayerPerceptron.train(); + MultilayerPerceptron multilayerPerceptron = new MultilayerPerceptron(networkConfiguration, testFeatures, testLabels); + multilayerPerceptron.train(features, labels); + multilayerPerceptron.evaluate(testFeatures, testLabels); } } diff --git a/example/src/main/java/de/example/svm/SVMExample.java b/example/src/main/java/de/example/svm/SVMExample.java index 80a194a..f245fcd 100644 --- a/example/src/main/java/de/example/svm/SVMExample.java +++ 
b/example/src/main/java/de/example/svm/SVMExample.java @@ -1,5 +1,6 @@ package de.example.svm; +import de.edux.api.Classifier; import de.edux.ml.svm.ISupportVectorMachine; import de.edux.ml.svm.SVMKernel; import de.edux.ml.svm.SupportVectorMachine; @@ -14,31 +15,20 @@ public static void main(String[] args){ datasetProvider.printStatistics(); //Get Features and Labels - double[][] features = datasetProvider.getTrainFeatures(); + var features = datasetProvider.getTrainFeatures(); // 1 - SATOSA 2 - VERSICOLOR 3 - VIRGINICA - int[] labels = convert2DLabelArrayTo1DLabelArray(datasetProvider.getTrainLabels()); + var labels = datasetProvider.getTrainLabels(); - ISupportVectorMachine supportVectorMachine = new SupportVectorMachine(SVMKernel.LINEAR, 1); + Classifier supportVectorMachine = new SupportVectorMachine(SVMKernel.LINEAR, 1); //ONEvsONE Strategy supportVectorMachine.train(features, labels); double[][] testFeatures = datasetProvider.getTestFeatures(); double[][] testLabels = datasetProvider.getTestLabels(); - int[] decisionTreeTestLabels = convert2DLabelArrayTo1DLabelArray(testLabels); - supportVectorMachine.evaluate(testFeatures, decisionTreeTestLabels); + supportVectorMachine.evaluate(testFeatures, testLabels); } - private static int[] convert2DLabelArrayTo1DLabelArray(double[][] labels) { - int[] decisionTreeTrainLabels = new int[labels.length]; - for (int i = 0; i < labels.length; i++) { - for (int j = 0; j < labels[i].length; j++) { - if (labels[i][j] == 1) { - decisionTreeTrainLabels[i] = (j+1); - } - } - } - return decisionTreeTrainLabels; - } + } diff --git a/lib/src/main/java/de/edux/api/Classifier.java b/lib/src/main/java/de/edux/api/Classifier.java new file mode 100644 index 0000000..d158480 --- /dev/null +++ b/lib/src/main/java/de/edux/api/Classifier.java @@ -0,0 +1,47 @@ +package de.edux.api; + +public interface Classifier { + + /** + * Trains the model using the provided training inputs and targets. 
+ * @param features 2D array of double, where each inner array represents a single set of input values to train the model on. + * @param labels 2D array of double, where each inner array represents the expected output or target for the corresponding set of inputs in {@code features}. + * @return true if the model was successfully trained, false otherwise. + */ + boolean train(double[][] features, double[][] labels); + /** + * Evaluates the model's performance against the provided test inputs and targets. + * + * This method takes a set of test inputs and their corresponding expected targets, + * applies the model to predict the outputs for the inputs, and then compares + * the predicted outputs to the expected targets to evaluate the performance + * of the model. The nature and metric of the evaluation (e.g., accuracy, MSE, etc.) + * are dependent on the specific implementation within the method. + * + * @param testInputs 2D array of double, where each inner array represents + * a single set of input values to be evaluated by the model. + * @param testTargets 2D array of double, where each inner array represents + * the expected output or target for the corresponding set + * of inputs in {@code testInputs}. + * @return a double value representing the performance of the model when evaluated + * against the provided test inputs and targets. The interpretation of this + * value (e.g., higher is better, lower is better, etc.) depends on the + * specific evaluation metric being used. + * @throws IllegalArgumentException if the lengths of {@code testInputs} and + * {@code testTargets} do not match, or if + * they are empty. + */ + double evaluate(double[][] testInputs, double[][] testTargets); + + + /** + * Predicts the output for a single set of input values. + * + * @param feature a single set of input values to be evaluated by the model. + * @return a double array representing the predicted output values for the + * provided input values. + * @throws IllegalArgumentException if {@code feature} is empty. 
+ */ + public double[] predict(double[] feature); + +} diff --git a/lib/src/main/java/de/edux/data/provider/DataPostProcessor.java b/lib/src/main/java/de/edux/data/provider/DataPostProcessor.java index 649d4a8..263276c 100644 --- a/lib/src/main/java/de/edux/data/provider/DataPostProcessor.java +++ b/lib/src/main/java/de/edux/data/provider/DataPostProcessor.java @@ -4,8 +4,21 @@ public abstract class DataPostProcessor { public abstract void normalize(List rowDataset); + public abstract T mapToDataRecord(String[] csvLine); + public abstract double[][] getInputs(List dataset); + public abstract double[][] getTargets(List dataset); + public abstract String getDatasetDescription(); + + public abstract double[][] getTrainFeatures(); + + public abstract double[][] getTrainLabels(); + + public abstract double[][] getTestLabels(); + + public abstract double[][] getTestFeatures(); + } diff --git a/lib/src/main/java/de/edux/data/provider/DataProcessor.java b/lib/src/main/java/de/edux/data/provider/DataProcessor.java new file mode 100644 index 0000000..e1f13bc --- /dev/null +++ b/lib/src/main/java/de/edux/data/provider/DataProcessor.java @@ -0,0 +1,81 @@ +package de.edux.data.provider; + +import de.edux.data.reader.CSVIDataReader; +import de.edux.data.reader.IDataReader; +import de.edux.ml.nn.network.api.Dataset; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public abstract class DataProcessor extends DataPostProcessor implements IDataUtil { + private static final Logger LOG = LoggerFactory.getLogger(DataProcessor.class); + private final IDataReader csvDataReader; + private ArrayList dataset; + private Dataset splitedDataset; + + public DataProcessor() { + this.csvDataReader = new CSVIDataReader(); + } + + public DataProcessor(IDataReader csvDataReader) { + this.csvDataReader = csvDataReader; + } + + @Override + public List loadDataSetFromCSV(File 
csvFile, char csvSeparator, boolean normalize, boolean shuffle, boolean filterIncompleteRecords) { + // Read the CSV exactly once and map each row to a data record. + List unmodifiableDataset = csvDataReader.readFile(csvFile, csvSeparator) + .stream() + .map(this::mapToDataRecord) + .filter(record -> !filterIncompleteRecords || record != null) + .toList(); + + dataset = new ArrayList<>(unmodifiableDataset); + LOG.info("Dataset loaded"); + + if (normalize) { + normalize(dataset); + LOG.info("Dataset normalized"); + } + + if (shuffle) { + Collections.shuffle(dataset); + LOG.info("Dataset shuffled"); + } + return dataset; + } + + /** + * Split data into train and test data + * + * @param data data to split + * @param trainTestSplitRatio ratio of train data + * @return dataset holding the train data and the test data; it is also kept and returned by getSplitedDataset(). + */ + @Override + public Dataset split(List data, double trainTestSplitRatio) { + if (trainTestSplitRatio < 0.0 || trainTestSplitRatio > 1.0) { + throw new IllegalArgumentException("Train-test split ratio must be between 0.0 and 1.0"); + } + + int trainSize = (int) (data.size() * trainTestSplitRatio); + + List trainDataset = data.subList(0, trainSize); + List testDataset = data.subList(trainSize, data.size()); + + splitedDataset = new Dataset<>(trainDataset, testDataset); + return splitedDataset; + } + public ArrayList getDataset() { + return dataset; + } + + public Dataset getSplitedDataset() { + return splitedDataset; + } +} + diff --git a/lib/src/main/java/de/edux/data/provider/DataUtil.java b/lib/src/main/java/de/edux/data/provider/DataUtil.java deleted file mode 100644 index d1c435f..0000000 --- a/lib/src/main/java/de/edux/data/provider/DataUtil.java +++ /dev/null @@ -1,70 +0,0 @@ -package de.edux.data.provider; - -import de.edux.data.reader.CSVIDataReader; -import de.edux.data.reader.IDataReader; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.File; -import java.util.ArrayList; -import 
java.util.Collections; -import java.util.List; - -public abstract class DataUtil extends DataPostProcessor implements IDataUtil { - private static final Logger logger = LoggerFactory.getLogger(DataUtil.class); - private final IDataReader csvDataReader; - public DataUtil() { - this.csvDataReader = new CSVIDataReader(); - } - - public DataUtil(IDataReader csvDataReader) { - this.csvDataReader = csvDataReader; - } - - @Override - public List loadTDataSet(File csvFile, char csvSeparator, boolean normalize, boolean shuffle, boolean filterIncompleteRecords) { - List x = csvDataReader.readFile(csvFile, csvSeparator); - List unmodifiableDataset = csvDataReader.readFile(csvFile, csvSeparator) - .stream() - .map(this::mapToDataRecord) - .filter(record -> !filterIncompleteRecords || record != null) - .toList(); - - List dataset = new ArrayList<>(unmodifiableDataset); - logger.info("Dataset loaded"); - - if (normalize) { - normalize(dataset); - logger.info("Dataset normalized"); - } - - if (shuffle) { - Collections.shuffle(dataset); - logger.info("Dataset shuffled"); - } - return dataset; - } - - /** - * Split dataset into train and test dataset - * - * @param dataset dataset to split - * @param trainTestSplitRatio ratio of train dataset - * @return list of train and test dataset. First element is train dataset, second element is test dataset. 
- */ - @Override - public List> split(List dataset, double trainTestSplitRatio) { - if (trainTestSplitRatio < 0.0 || trainTestSplitRatio > 1.0) { - throw new IllegalArgumentException("Train-test split ratio must be between 0.0 and 1.0"); - } - - int trainSize = (int) (dataset.size() * trainTestSplitRatio); - - List trainDataset = dataset.subList(0, trainSize); - List testDataset = dataset.subList(trainSize, dataset.size()); - - return List.of(trainDataset, testDataset); - } - -} - diff --git a/lib/src/main/java/de/edux/data/provider/IDataUtil.java b/lib/src/main/java/de/edux/data/provider/IDataUtil.java index 8e4d9d6..1f85516 100644 --- a/lib/src/main/java/de/edux/data/provider/IDataUtil.java +++ b/lib/src/main/java/de/edux/data/provider/IDataUtil.java @@ -1,12 +1,14 @@ package de.edux.data.provider; +import de.edux.ml.nn.network.api.Dataset; + import java.io.File; import java.util.List; public interface IDataUtil { - List loadTDataSet(File csvFile, char csvSeparator, boolean normalize, boolean shuffle, boolean filterIncompleteRecords); + List loadDataSetFromCSV(File csvFile, char csvSeparator, boolean normalize, boolean shuffle, boolean filterIncompleteRecords); - List> split(List dataset, double trainTestSplitRatio); + Dataset split(List dataset, double trainTestSplitRatio); double[][] getInputs(List dataset); diff --git a/lib/src/main/java/de/edux/ml/decisiontree/DecisionTree.java b/lib/src/main/java/de/edux/ml/decisiontree/DecisionTree.java index 4d2793e..aa9844b 100644 --- a/lib/src/main/java/de/edux/ml/decisiontree/DecisionTree.java +++ b/lib/src/main/java/de/edux/ml/decisiontree/DecisionTree.java @@ -1,243 +1,272 @@ package de.edux.ml.decisiontree; +import de.edux.api.Classifier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.Arrays; -import java.util.Map; +import java.util.*; import java.util.function.Function; import java.util.stream.Collectors; /** - * A decision tree classifier. - *

- * This class implements a binary decision tree algorithm for classification. - * The decision tree is built by recursively splitting the training data based on - * the feature that results in the minimum Gini index, which is a measure of impurity. - *

-*/ -public class DecisionTree implements IDecisionTree { - private static final Logger LOG = LoggerFactory.getLogger(DecisionTree.class); - private Node root; - private int maxDepth; - private int minSamplesSplit; - private int minSamplesLeaf; - private int maxLeafNodes; - - private double calculateGiniIndex(double[] labels) { - if (labels.length == 0) { - return 0.0; + * A Decision Tree classifier for predictive modeling. + * + *

The {@code DecisionTree} class is a binary tree where each node represents a decision + * on a particular feature from the input feature vector, effectively partitioning the + * input space into regions with similar output labels. The tree is built recursively + * by selecting splits that minimize the Gini impurity of the resultant partitions. + * + *

Features: + *

    + *
  • Supports binary and multiclass classification problems.
  • + *
  • Utilizes the Gini impurity to determine optimal feature splits.
  • + *
  • Enables control over tree depth and complexity through various hyperparameters.
  • + *
+ * + *

Hyperparameters include: + *

    + *
  • {@code maxDepth}: The maximum depth of the tree.
  • + *
  • {@code minSamplesSplit}: The minimum number of samples required to split an internal node.
  • + *
  • {@code minSamplesLeaf}: The minimum number of samples required to be at a leaf node.
  • + *
  • {@code maxLeafNodes}: The maximum number of leaf nodes in the tree.
  • + *
+ * + *

Usage example: + *

{@code
+ * DecisionTree classifier = new DecisionTree(10, 2, 1, 50);
+ * classifier.train(trainingFeatures, trainingLabels);
+ * double accuracy = classifier.evaluate(testFeatures, testLabels);
+ * }
+ * + *

Note: This class requires a thorough validation of input data and parameters, ensuring + * they are never {@code null}, have appropriate dimensions, and adhere to any other + * prerequisites or assumptions, to guarantee robustness and avoid runtime exceptions. + * + * @see Classifier + */ +public class DecisionTree implements Classifier { + private static final Logger LOG = LoggerFactory.getLogger(DecisionTree.class); + private Node root; + private final int maxDepth; + private final int minSamplesSplit; + private final int minSamplesLeaf; + private final int maxLeafNodes; + private final Map featureImportances; + + public DecisionTree(int maxDepth, + int minSamplesSplit, + int minSamplesLeaf, + int maxLeafNodes) { + this.maxDepth = maxDepth; + this.minSamplesSplit = minSamplesSplit; + this.minSamplesLeaf = minSamplesLeaf; + this.maxLeafNodes = maxLeafNodes; + this.featureImportances = new HashMap<>(); } - int[] counts = new int[(int) Arrays.stream(labels).max().getAsDouble() + 1]; - Arrays.stream(labels).forEach(label -> counts[(int) label]++); - return 1.0 - - Arrays.stream(counts) - .mapToDouble(count -> Math.pow((double) count / labels.length, 2)) - .sum(); - } - - private double calculateSplitGiniIndex(double[][] leftData, double[][] rightData) { - double[] leftLabels = Arrays.stream(leftData).mapToDouble(row -> row[row.length - 1]).toArray(); - double[] rightLabels = - Arrays.stream(rightData).mapToDouble(row -> row[row.length - 1]).toArray(); - double leftGiniIndex = calculateGiniIndex(leftLabels); - double rightGiniIndex = calculateGiniIndex(rightLabels); - return leftData.length * leftGiniIndex / (leftData.length + rightData.length) - + rightData.length * rightGiniIndex / (leftData.length + rightData.length); - } - - private double[][] splitData(double[][] data, int column, double value) { - return Arrays.stream(data).filter(row -> row[column] < value).toArray(double[][]::new); - } - - private double[][] getFeatures(double[][] data) { - return 
Arrays.stream(data) - .map(row -> Arrays.copyOf(row, row.length - 1)) - .toArray(double[][]::new); - } - - private void buildTree(Node node) { - if (node.data.length <= minSamplesSplit - || getDepth(node) >= maxDepth - || getLeafCount(node) >= maxLeafNodes) { - node.isLeaf = true; - return; - } + @Override + public boolean train(double[][] features, double[][] labels) { + try { + if (features == null || labels == null || features.length == 0 || labels.length == 0 || features.length != labels.length) { + LOG.error("Invalid training data"); + return false; + } + + this.root = buildTree(features, labels, 0); - double minGiniIndex = Double.MAX_VALUE; - int minColumn = -1; - double minValue = Double.MAX_VALUE; - double[][] leftData = new double[0][]; - double[][] rightData = new double[0][]; - - for (int column = 0; column < getFeatures(node.data)[0].length; column++) { - for (double[] row : node.data) { - double[][] leftSplitData = splitData(node.data, column, row[column], true); - double[][] rightSplitData = splitData(node.data, column, row[column], false); - double giniIndex = calculateSplitGiniIndex(leftSplitData, rightSplitData); - - if (giniIndex < minGiniIndex && leftSplitData.length >= minSamplesLeaf && rightSplitData.length >= minSamplesLeaf) { - minGiniIndex = giniIndex; - minColumn = column; - minValue = row[column]; - leftData = leftSplitData; - rightData = rightSplitData; + return true; + } catch (Exception e) { + LOG.error("An error occurred during training", e); + return false; } - } } + private Node buildTree(double[][] features, double[][] labels, int depth) { + Node node = new Node(features); + node.predictedLabel = getMajorityLabel(labels); - if (minColumn == -1) { - node.isLeaf = true; - } else { - node.splitFeature = minColumn; - node.value = minValue; - node.left = new Node(leftData); - node.right = new Node(rightData); - buildTree(node.left); - buildTree(node.right); - } - } - - @Override - public void train( - double[][] features, - double[][] 
labels, - int maxDepth, - int minSamplesSplit, - int minSamplesLeaf, - int maxLeafNodes) { - this.maxDepth = maxDepth; - this.minSamplesSplit = minSamplesSplit; - this.minSamplesLeaf = minSamplesLeaf; - this.maxLeafNodes = maxLeafNodes; - - double[][] data = new double[features.length][]; - for (int i = 0; i < features.length; i++) { - data[i] = Arrays.copyOf(features[i], features[i].length + 1); - data[i][data[i].length - 1] = getIndexOfHighestValue(labels[i]); - } - root = new Node(data); - buildTree(root); - } + boolean maxDepthReached = depth >= maxDepth; + boolean tooFewSamples = features.length < minSamplesSplit; + int currentLeafNodes = 0; + boolean maxLeafNodesReached = currentLeafNodes >= maxLeafNodes; - private double getIndexOfHighestValue(double[] labels) { - if (labels == null || labels.length == 0) { - throw new IllegalArgumentException("Array must not be null or empty"); - } + if (maxDepthReached || tooFewSamples || maxLeafNodesReached) { + return node; + } + + double bestGini = Double.MAX_VALUE; + double[][] bestLeftFeatures = null, bestRightFeatures = null; + double[][] bestLeftLabels = null, bestRightLabels = null; + + for (int featureIndex = 0; featureIndex < features[0].length; featureIndex++) { + for (double[] feature : features) { + double[][] leftFeatures = filterRows(features, featureIndex, feature[featureIndex], true); + double[][] rightFeatures = filterRows(features, featureIndex, feature[featureIndex], false); + + double[][] leftLabels = filterRows(labels, leftFeatures, features); + double[][] rightLabels = filterRows(labels, rightFeatures, features); + + double gini = computeGini(leftLabels, rightLabels); + + if (gini < bestGini) { + bestGini = gini; + updateFeatureImportances(featureIndex, gini); + bestLeftFeatures = leftFeatures; + bestRightFeatures = rightFeatures; + bestLeftLabels = leftLabels; + bestRightLabels = rightLabels; + node.splitFeatureIndex = featureIndex; + node.splitValue = feature[featureIndex]; + } + } + } - int 
maxIndex = 0; - double maxValue = labels[0]; + if (bestLeftFeatures != null && bestRightFeatures != null && + bestLeftFeatures.length >= minSamplesLeaf && bestRightFeatures.length >= minSamplesLeaf) { - for (int i = 1; i < labels.length; i++) { - if (labels[i] > maxValue) { - maxValue = labels[i]; - maxIndex = i; - } + Node leftChild = buildTree(bestLeftFeatures, bestLeftLabels, depth + 1); + Node rightChild = buildTree(bestRightFeatures, bestRightLabels, depth + 1); + + if(currentLeafNodes + 2 <= maxLeafNodes) { + node.left = leftChild; + node.right = rightChild; + } + } + + return node; + } + + private void updateFeatureImportances(int featureIndex, double giniReduction) { + featureImportances.merge(featureIndex, giniReduction, Double::sum); } - return maxIndex; - } + public Map getFeatureImportances() { + double totalImportance = featureImportances.values().stream().mapToDouble(Double::doubleValue).sum(); + return featureImportances.entrySet().stream() + .collect(Collectors.toMap( + Map.Entry::getKey, + e -> e.getValue() / totalImportance)); + } - @Override - public double predict(double[] feature) { - return predict(feature, root); - } - private double predict(double[] feature, Node node) { - if (node.isLeaf) { - return getMostCommonLabel(node.data); + private double[][] filterRows(double[][] matrix, int featureIndex, double value, boolean lessThan) { + return Arrays.stream(matrix) + .filter(row -> (lessThan && row[featureIndex] < value) || (!lessThan && row[featureIndex] >= value)) + .toArray(double[][]::new); } - if (feature[node.splitFeature] < node.value) { - return predict(feature, node.left); - } else { - return predict(feature, node.right); + private double[][] filterRows(double[][] labels, double[][] filteredFeatures, double[][] originalFeatures) { + List filteredLabelsList = new ArrayList<>(); + for (double[] filteredFeature : filteredFeatures) { + for (int i = 0; i < originalFeatures.length; i++) { + if (Arrays.equals(filteredFeature, 
originalFeatures[i])) { + filteredLabelsList.add(labels[i]); + break; + } + } + } + return filteredLabelsList.toArray(new double[0][0]); } - } - - private double getMostCommonLabel(double[][] data) { - return Arrays.stream(data) - .mapToDouble(row -> row[row.length - 1]) - .boxed() - .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())) - .entrySet() - .stream() - .max(Map.Entry.comparingByValue()) - .get() - .getKey(); - } - - @Override - public double evaluate(double[][] features, double[][] labels) { - int correctPredictions = 0; - for (int i = 0; i < features.length; i++) { - double predictedLabel = predict(features[i]); - double actualLabel = getIndexOfHighestValue(labels[i]); - - if (predictedLabel == actualLabel) { - correctPredictions++; - } + + private double[] getMajorityLabel(double[][] labels) { + return Arrays.stream(labels) + .map(Arrays::toString) // Convert double[] to String for grouping + .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())) + .entrySet().stream() + .max(Map.Entry.comparingByValue()) + .map(Map.Entry::getKey) + .map(str -> Arrays.stream(str.substring(1, str.length() - 1).split(", ")) + .mapToDouble(Double::parseDouble).toArray()) // Convert String back to double[] + .orElseThrow(RuntimeException::new); } - double accuracy = (double) correctPredictions / features.length; - LOG.info("Model Accuracy: {}%", accuracy * 100); - return accuracy; - } + @Override + public double evaluate(double[][] testInputs, double[][] testTargets) { + if (testInputs == null || testTargets == null || testInputs.length == 0 || testTargets.length == 0 || testInputs.length != testTargets.length) { + LOG.error("Invalid test data"); + return 0; + } + long correctPredictions = 0; + for (int i = 0; i < testInputs.length; i++) { + double[] prediction = predict(testInputs[i]); + if (Arrays.equals(prediction, testTargets[i])) { + correctPredictions++; + } + } - @Override - public double[] getFeatureImportance() { - int 
numFeatures = root.data[0].length - 1; - double[] importances = new double[numFeatures]; - calculateFeatureImportance(root, importances); - return importances; - } + double accuracy = (double) correctPredictions / testInputs.length; + LOG.info(String.format("Decision Tree - accuracy: %.2f%%", accuracy * 100)); + return accuracy; + } - private double calculateFeatureImportance(Node node, double[] importances) { - if (node == null || node.isLeaf) { - return 0; + @Override + public double[] predict(double[] feature) { + return predictRecursive(root, feature); } - double importance = calculateGiniIndex(getLabels(node.data)) - - calculateSplitGiniIndex(node.left.data, node.right.data); - importances[node.splitFeature] += importance; + private double[] predictRecursive(Node node, double[] feature) { + if (node == null || feature == null) { + throw new IllegalArgumentException("Node and feature cannot be null"); + } + + if (node.left == null && node.right == null) { + return node.predictedLabel; + } + + if (node.splitFeatureIndex >= feature.length) { + throw new IllegalArgumentException("splitFeatureIndex is out of bounds of feature array"); + } - return importance + calculateFeatureImportance(node.left, importances) + calculateFeatureImportance(node.right, importances); - } + if (feature[node.splitFeatureIndex] < node.splitValue) { + if (node.left == null) { + throw new IllegalStateException("Left node is null when trying to traverse left"); + } + return predictRecursive(node.left, feature); + } else { + if (node.right == null) { + throw new IllegalStateException("Right node is null when trying to traverse right"); + } + return predictRecursive(node.right, feature); + } + } - private double[] getLabels(double[][] data) { - return Arrays.stream(data) - .mapToDouble(row -> row[row.length - 1]) - .toArray(); - } - private int getDepth(Node node) { - if (node == null) { - return 0; + private double computeGini(double[][] leftLabels, double[][] rightLabels) { + double 
leftImpurity = computeImpurity(leftLabels); + double rightImpurity = computeImpurity(rightLabels); + double leftWeight = ((double) leftLabels.length) / (leftLabels.length + rightLabels.length); + double rightWeight = ((double) rightLabels.length) / (leftLabels.length + rightLabels.length); + return leftWeight * leftImpurity + rightWeight * rightImpurity; } - return Math.max(getDepth(node.left), getDepth(node.right)) + 1; - } - private double[][] splitData(double[][] data, int column, double value, boolean isLeftSplit) { - if (isLeftSplit) { - return Arrays.stream(data).filter(row -> row[column] < value).toArray(double[][]::new); - } else { - return Arrays.stream(data).filter(row -> row[column] >= value).toArray(double[][]::new); + private double computeImpurity(double[][] labels) { + double impurity = 1.0; + Map labelCounts = Arrays.stream(labels) + .map(Arrays::toString) + .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())); + for (Long count : labelCounts.values()) { + double p = ((double) count) / labels.length; + impurity -= p * p; + } + return impurity; } - } - - private int getLeafCount(Node node) { - if (node == null) { - return 0; - } else if (node.isLeaf) { - return 1; - } else { - return getLeafCount(node.left) + getLeafCount(node.right); + + + + static class Node { + double[][] data; + Node left; + Node right; + int splitFeatureIndex; + double splitValue; + double[] predictedLabel; + + public Node(double[][] data) { + this.data = data; + } } - } } \ No newline at end of file diff --git a/lib/src/main/java/de/edux/ml/decisiontree/IDecisionTree.java b/lib/src/main/java/de/edux/ml/decisiontree/IDecisionTree.java index 69b4cc7..090ae3d 100644 --- a/lib/src/main/java/de/edux/ml/decisiontree/IDecisionTree.java +++ b/lib/src/main/java/de/edux/ml/decisiontree/IDecisionTree.java @@ -2,30 +2,6 @@ public interface IDecisionTree { - void train( - double[][] features, - double[][] labels, - int maxDepth, - int minSamplesSplit, - int 
minSamplesLeaf, - int maxLeafNodes); - - /** - * Pedicts the label for the given features. - * - * @param features the features to predict - * @return the predicted label - */ - double predict(double[] features); - /** - * Evaluates the given features and labels against the decision tree. - * - * @param features the features to evaluate - * @param labels the labels to evaluate - * @return true if the decision tree correctly classified the features and labels, false otherwise - */ - double evaluate(double[][] features, double[][] labels); - /** * Returns the feature importance of the decision tree. * diff --git a/lib/src/main/java/de/edux/ml/knn/ILabeledPoint.java b/lib/src/main/java/de/edux/ml/knn/ILabeledPoint.java deleted file mode 100644 index 5187eb9..0000000 --- a/lib/src/main/java/de/edux/ml/knn/ILabeledPoint.java +++ /dev/null @@ -1,6 +0,0 @@ -package de.edux.ml.knn; - -public interface ILabeledPoint { - double[] getFeatures(); - String getLabel(); -} \ No newline at end of file diff --git a/lib/src/main/java/de/edux/ml/knn/KnnClassifier.java b/lib/src/main/java/de/edux/ml/knn/KnnClassifier.java index 9ea1a86..7315312 100644 --- a/lib/src/main/java/de/edux/ml/knn/KnnClassifier.java +++ b/lib/src/main/java/de/edux/ml/knn/KnnClassifier.java @@ -1,66 +1,93 @@ package de.edux.ml.knn; +import de.edux.api.Classifier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.*; -import java.util.stream.Collectors; +import java.util.Arrays; +import java.util.PriorityQueue; -public class KnnClassifier { - private static final Logger LOG = LoggerFactory.getLogger(KnnClassifier.class); - private final int k; - private List trainingPoints; +public class KnnClassifier implements Classifier { + Logger LOG = LoggerFactory.getLogger(KnnClassifier.class); + private double[][] trainFeatures; + private double[][] trainLabels; + private int k; + private static final double EPSILON = 1e-10; - public KnnClassifier(int k, List trainingPoints) { + public 
KnnClassifier(int k) { + if (k <= 0) { + throw new IllegalArgumentException("k must be a positive integer"); + } this.k = k; - this.trainingPoints = trainingPoints; } - private double distance(ILabeledPoint a, ILabeledPoint b) { - double[] aFeatures = a.getFeatures(); - double[] bFeatures = b.getFeatures(); - - if (aFeatures.length != bFeatures.length) { - throw new IllegalArgumentException("Both points must have the same number of features"); + @Override + public boolean train(double[][] features, double[][] labels) { + if (features.length == 0 || features.length != labels.length) { + return false; } + this.trainFeatures = features; + this.trainLabels = labels; + return true; + } - double sum = 0.0; - for (int i = 0; i < aFeatures.length; i++) { - sum += Math.pow(aFeatures[i] - bFeatures[i], 2); + @Override + public double evaluate(double[][] testInputs, double[][] testTargets) { + LOG.info("Evaluating..."); + int correct = 0; + for (int i = 0; i < testInputs.length; i++) { + double[] prediction = predict(testInputs[i]); + if (Arrays.equals(prediction, testTargets[i])) { + correct++; + } } - return Math.sqrt(sum); + double accuracy = (double) correct / testInputs.length; + LOG.info("KNN - Accuracy: " + accuracy * 100 + "%"); + return accuracy; } - public String classify(ILabeledPoint unknown) { - PriorityQueue nearestNeighbors = new PriorityQueue<>( - Comparator.comparingDouble(p -> distance(p, unknown)) - ); + @Override + public double[] predict(double[] feature) { + PriorityQueue pq = new PriorityQueue<>((a, b) -> Double.compare(b.distance, a.distance)); + for (int i = 0; i < trainFeatures.length; i++) { + double distance = calculateDistance(trainFeatures[i], feature); + pq.offer(new Neighbor(distance, trainLabels[i])); + if (pq.size() > k) { + pq.poll(); + } + } - for (ILabeledPoint p : trainingPoints) { - if (nearestNeighbors.size() < k) { - nearestNeighbors.add(p); - } else if (distance(p, unknown) < distance(nearestNeighbors.peek(), unknown)) { - 
nearestNeighbors.poll(); - nearestNeighbors.add(p); + double[] aggregatedLabel = new double[trainLabels[0].length]; + double totalWeight = 0; + for (Neighbor neighbor : pq) { + double weight = 1 / (neighbor.distance + EPSILON); + for (int i = 0; i < aggregatedLabel.length; i++) { + aggregatedLabel[i] += neighbor.label[i] * weight; } + totalWeight += weight; } - Map labelCounts = nearestNeighbors.stream() - .collect(Collectors.groupingBy(ILabeledPoint::getLabel, Collectors.counting())); - return Collections.max(labelCounts.entrySet(), Map.Entry.comparingByValue()).getKey(); + for (int i = 0; i < aggregatedLabel.length; i++) { + aggregatedLabel[i] /= totalWeight; + } + return aggregatedLabel; } - public double evaluate(List testPoints) { - int correct = 0; - - for (ILabeledPoint p : testPoints) { - if (classify(p).equals(p.getLabel())) { - correct++; - } + private double calculateDistance(double[] a, double[] b) { + double sum = 0; + for (int i = 0; i < a.length; i++) { + sum += Math.pow(a[i] - b[i], 2); } + return Math.sqrt(sum); + } - double accuracy = (double) correct / testPoints.size()*100; - LOG.info("Accuracy: {}%", accuracy); - return accuracy; + private static class Neighbor { + private double distance; + private double[] label; + + public Neighbor(double distance, double[] label) { + this.distance = distance; + this.label = label; + } } } diff --git a/lib/src/main/java/de/edux/ml/knn/KnnPoint.java b/lib/src/main/java/de/edux/ml/knn/KnnPoint.java deleted file mode 100644 index 3a4eed2..0000000 --- a/lib/src/main/java/de/edux/ml/knn/KnnPoint.java +++ /dev/null @@ -1,22 +0,0 @@ -package de.edux.ml.knn; - -public class KnnPoint implements ILabeledPoint{ - - private final String label; - private final double[] features; - - public KnnPoint(double[] features, String label) { - this.features = features; - this.label = label; - } - - @Override - public double[] getFeatures() { - return features; - } - - @Override - public String getLabel() { - return label; - } -} 
diff --git a/lib/src/main/java/de/edux/ml/nn/Neuron.java b/lib/src/main/java/de/edux/ml/nn/Neuron.java index ecb34b7..b77516e 100644 --- a/lib/src/main/java/de/edux/ml/nn/Neuron.java +++ b/lib/src/main/java/de/edux/ml/nn/Neuron.java @@ -4,6 +4,7 @@ import de.edux.functions.initialization.Initialization; public class Neuron { + private final Initialization initialization; private double[] weights; private double bias; private final ActivationFunction activationFunction; @@ -11,11 +12,16 @@ public class Neuron { public Neuron(int inputSize, ActivationFunction activationFunction, Initialization initialization) { this.weights = new double[inputSize]; this.activationFunction = activationFunction; + this.initialization = initialization; this.bias = initialization.weightInitialization(inputSize, new double[1])[0]; this.weights = initialization.weightInitialization(inputSize, weights); } + public Initialization getInitialization() { + return initialization; + } + public double calculateOutput(double[] input) { double output = bias; for (int i = 0; i < input.length; i++) { @@ -49,4 +55,12 @@ public double getBias() { public ActivationFunction getActivationFunction() { return activationFunction; } + + public void setWeights(double[] weights) { + this.weights = weights; + } + + public void setBias(double bias) { + this.bias = bias; + } } diff --git a/lib/src/main/java/de/edux/ml/nn/network/MultilayerPerceptron.java b/lib/src/main/java/de/edux/ml/nn/network/MultilayerPerceptron.java index caada09..82fb6a6 100644 --- a/lib/src/main/java/de/edux/ml/nn/network/MultilayerPerceptron.java +++ b/lib/src/main/java/de/edux/ml/nn/network/MultilayerPerceptron.java @@ -1,5 +1,6 @@ package de.edux.ml.nn.network; +import de.edux.api.Classifier; import de.edux.functions.initialization.Initialization; import de.edux.ml.nn.Neuron; import de.edux.functions.activation.ActivationFunction; @@ -10,26 +11,24 @@ import java.util.ArrayList; import java.util.List; -public class MultilayerPerceptron { 
+public class MultilayerPerceptron implements Classifier { private static final Logger LOG = LoggerFactory.getLogger(MultilayerPerceptron.class); - private final double[][] inputs; - private final double[][] targets; private final NetworkConfiguration config; private final ActivationFunction hiddenLayerActivationFunction; private final ActivationFunction outputLayerActivationFunction; - private final double[][] testInputs; - private final double[][] testTargets; - private final List hiddenLayers; - private final Neuron[] outputLayer; + private List hiddenLayers; + private Neuron[] outputLayer; + private final double[][] testFeatures; + private final double[][] testLabels; private double bestAccuracy; + private ArrayList bestHiddenLayers; + private Neuron[] bestOutputLayer; - public MultilayerPerceptron(double[][] inputs, double[][] targets, double[][] testInputs, double[][] testTargets, NetworkConfiguration config) { - this.inputs = inputs; - this.targets = targets; - this.testInputs = testInputs; - this.testTargets = testTargets; + public MultilayerPerceptron(NetworkConfiguration config, double[][] testFeatures, double[][] testLabels) { this.config = config; + this.testFeatures = testFeatures; + this.testLabels = testLabels; hiddenLayerActivationFunction = config.hiddenLayerActivationFunction(); outputLayerActivationFunction = config.outputLayerActivationFunction(); @@ -55,7 +54,6 @@ public MultilayerPerceptron(double[][] inputs, double[][] targets, double[][] te private double[] feedforward(double[] input) { double[] currentInput = input; - // Pass input through all hidden layers for (Neuron[] layer : hiddenLayers) { double[] hiddenOutputs = new double[layer.length]; for (int i = 0; i < layer.length; i++) { @@ -64,7 +62,6 @@ private double[] feedforward(double[] input) { currentInput = hiddenOutputs; } - // Pass input through output layer double[] output = new double[config.outputSize()]; for (int i = 0; i < config.outputSize(); i++) { output[i] = 
outputLayer[i].calculateOutput(currentInput); @@ -73,16 +70,20 @@ private double[] feedforward(double[] input) { return outputLayerActivationFunction.calculateActivation(output); } - public void train() { - bestAccuracy = 0; + @Override + public boolean train(double[][] features, double[][] labels) { + bestAccuracy = 0; + int epochsWithoutImprovement = 0; + final int PATIENCE = 10; + for (int epoch = 0; epoch < config.epochs(); epoch++) { - for (int i = 0; i < inputs.length; i++) { - double[] output = feedforward(inputs[i]); + for (int i = 0; i < features.length; i++) { + double[] output = feedforward(features[i]); - // Calculate error signals double[] output_error_signal = new double[config.outputSize()]; - for (int j = 0; j < config.outputSize(); j++) - output_error_signal[j] = targets[i][j] - output[j]; + for (int j = 0; j < config.outputSize(); j++) { + output_error_signal[j] = labels[i][j] - output[j]; + } List hidden_error_signals = new ArrayList<>(); for (int j = hiddenLayers.size() - 1; j >= 0; j--) { @@ -96,31 +97,63 @@ public void train() { output_error_signal = hidden_error_signal; } - updateWeights(i, output_error_signal, hidden_error_signals); + + updateWeights(i, output_error_signal, hidden_error_signals, features); } - if (epoch % 10 == 0) { - double accuracy = evaluate(testInputs, testTargets) * 100; - if (accuracy == 100) { - LOG.info("Stop training at: {}%", String.format("%.2f", accuracy)); - return; - } + double accuracy = evaluate(testFeatures, testLabels); + LOG.info("Epoch: {} - Accuracy: {}%", epoch, String.format("%.2f", accuracy * 100)); - if (accuracy > bestAccuracy) { - bestAccuracy = accuracy; - LOG.info("Best Accuracy: {}%", String.format("%.2f", bestAccuracy)); - } - // if accuracy 20% lower than best accuracy, stop training - if (bestAccuracy - accuracy > 20) { - LOG.info("Local Minama found, stop training"); - return; + if (accuracy > bestAccuracy) { + bestAccuracy = accuracy; + epochsWithoutImprovement = 0; + 
saveBestModel(hiddenLayers, outputLayer); + } else { + epochsWithoutImprovement++; + } + + if (epochsWithoutImprovement >= PATIENCE) { + LOG.info("Early stopping: Stopping training as the model has not improved in the last {} epochs.", PATIENCE); + loadBestModel(); + LOG.info("Best accuracy after restoring best MLP model: {}%", String.format("%.2f", bestAccuracy * 100)); + break; + } + } + return true; + } + + private void loadBestModel() { + this.hiddenLayers = this.bestHiddenLayers; + this.outputLayer = this.bestOutputLayer; + } + + private void saveBestModel(List hiddenLayers, Neuron[] outputLayer) { + this.bestHiddenLayers = new ArrayList<>(); + this.bestOutputLayer = new Neuron[outputLayer.length]; + for (int i = 0; i < hiddenLayers.size(); i++) { + Neuron[] layer = hiddenLayers.get(i); + Neuron[] newLayer = new Neuron[layer.length]; + for (int j = 0; j < layer.length; j++) { + newLayer[j] = new Neuron(layer[j].getWeights().length, layer[j].getActivationFunction(), layer[j].getInitialization()); + newLayer[j].setBias(layer[j].getBias()); + for (int k = 0; k < layer[j].getWeights().length; k++) { + newLayer[j].getWeights()[k] = layer[j].getWeight(k); } } + this.bestHiddenLayers.add(newLayer); } + for (int i = 0; i < outputLayer.length; i++) { + this.bestOutputLayer[i] = new Neuron(outputLayer[i].getWeights().length, outputLayer[i].getActivationFunction(), outputLayer[i].getInitialization()); + this.bestOutputLayer[i].setBias(outputLayer[i].getBias()); + for (int j = 0; j < outputLayer[i].getWeights().length; j++) { + this.bestOutputLayer[i].getWeights()[j] = outputLayer[i].getWeight(j); + } + } + } - private void updateWeights(int i, double[] output_error_signal, List hidden_error_signals) { - double[] currentInput = inputs[i]; + private void updateWeights(int i, double[] output_error_signal, List hidden_error_signals, double[][] features) { + double[] currentInput = features[i]; for (int j = 0; j < hiddenLayers.size(); j++) { Neuron[] layer = 
hiddenLayers.get(j); @@ -131,7 +164,7 @@ private void updateWeights(int i, double[] output_error_signal, List h } currentInput = new double[layer.length]; for (int k = 0; k < layer.length; k++) { - currentInput[k] = layer[k].calculateOutput(inputs[i]); + currentInput[k] = layer[k].calculateOutput(features[i]); } } @@ -141,7 +174,8 @@ private void updateWeights(int i, double[] output_error_signal, List h } } - private double evaluate(double[][] testInputs, double[][] testTargets) { + @Override + public double evaluate(double[][] testInputs, double[][] testTargets) { int correctCount = 0; for (int i = 0; i < testInputs.length; i++) { @@ -167,7 +201,4 @@ public double[] predict(double[] input) { return feedforward(input); } - public double getAccuracy() { - return bestAccuracy; - } } diff --git a/lib/src/main/java/de/edux/ml/nn/network/api/Dataset.java b/lib/src/main/java/de/edux/ml/nn/network/api/Dataset.java new file mode 100644 index 0000000..e4a159c --- /dev/null +++ b/lib/src/main/java/de/edux/ml/nn/network/api/Dataset.java @@ -0,0 +1,5 @@ +package de.edux.ml.nn.network.api; + +import java.util.List; + +public record Dataset(List trainData, List testData) {} diff --git a/lib/src/main/java/de/edux/ml/randomforest/RandomForest.java b/lib/src/main/java/de/edux/ml/randomforest/RandomForest.java index 49f559e..4db570c 100644 --- a/lib/src/main/java/de/edux/ml/randomforest/RandomForest.java +++ b/lib/src/main/java/de/edux/ml/randomforest/RandomForest.java @@ -1,5 +1,6 @@ package de.edux.ml.randomforest; +import de.edux.api.Classifier; import de.edux.ml.decisiontree.DecisionTree; import de.edux.ml.decisiontree.IDecisionTree; import org.slf4j.Logger; @@ -11,35 +12,76 @@ import java.util.Map; import java.util.concurrent.*; -public class RandomForest { +/** + *

RandomForest Classifier

+ * RandomForest is an ensemble learning method, which constructs a multitude of decision trees + * at training time and outputs the class that is the mode of the classes output by + * individual trees, or a mean prediction of the individual trees (regression). + *

+ * Note: Training and prediction are performed in parallel using a fixed-size thread pool. + * RandomForest handles the training of individual decision trees and their predictions, and + * determines the final prediction by majority vote over the classes predicted by the decision + * trees in the forest (this implementation does not average outputs for regression). RandomForest + * is particularly well suited for multiclass classification on datasets with complex structures. + *

+ * Usage example: + *

+ * {@code
+ * RandomForest forest = new RandomForest(numTrees, maxDepth, minSamplesSplit, minSamplesLeaf,
+ *                                        maxLeafNodes, numberOfFeatures);
+ * forest.train(features, labels);
+ * double[] prediction = forest.predict(sampleFeatures);
+ * double accuracy = forest.evaluate(testFeatures, testLabels);
+ * }
+ * 
+ *

+ * Thread Safety: This class uses concurrent features but may not be entirely thread-safe + * and should be used with caution in a multithreaded environment. + *

+ * Use {@link #train(double[][], double[][])} to train the forest, + * {@link #predict(double[])} to predict a single sample, and {@link #evaluate(double[][], double[][])} + * to evaluate accuracy against a test set. + */ +public class RandomForest implements Classifier { private static final Logger LOG = LoggerFactory.getLogger(RandomForest.class); - private final List trees = new ArrayList<>(); + private final List trees = new ArrayList<>(); private final ThreadLocalRandom threadLocalRandom = ThreadLocalRandom.current(); + private final int numTrees; + private final int maxDepth; + private final int minSamplesSplit; + private final int minSamplesLeaf; + private final int maxLeafNodes; + private final int numberOfFeatures; + + public RandomForest(int numTrees, int maxDepth, + int minSamplesSplit, + int minSamplesLeaf, + int maxLeafNodes, + int numberOfFeatures) { + this.numTrees = numTrees; + this.maxDepth = maxDepth; + this.minSamplesSplit = minSamplesSplit; + this.minSamplesLeaf = minSamplesLeaf; + this.maxLeafNodes = maxLeafNodes; + this.numberOfFeatures = numberOfFeatures; + } - public void train(int numTrees, - double[][] features, - double[][] labels, - int maxDepth, - int minSamplesSplit, - int minSamplesLeaf, - int maxLeafNodes, - int numberOfFeatures) { - + public boolean train(double[][] features, double[][] labels) { ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()); - List> futures = new ArrayList<>(); + List> futures = new ArrayList<>(); for (int i = 0; i < numTrees; i++) { futures.add(executor.submit(() -> { - IDecisionTree tree = new DecisionTree(); + Classifier tree = new DecisionTree(maxDepth, minSamplesSplit, minSamplesLeaf, maxLeafNodes); Sample subsetSample = getRandomSubset(numberOfFeatures, features, labels); - tree.train(subsetSample.featureSamples(), subsetSample.labelSamples(), maxDepth, minSamplesSplit, minSamplesLeaf, maxLeafNodes); + tree.train(subsetSample.featureSamples(), 
subsetSample.labelSamples()); return tree; })); } - for (Future future : futures) { + for (Future future : futures) { try { trees.add(future.get()); } catch (ExecutionException | InterruptedException e) { @@ -56,6 +98,7 @@ public void train(int numTrees, executor.shutdownNow(); Thread.currentThread().interrupt(); } + return true; } private Sample getRandomSubset(int numberOfFeatures, double[][] features, double[][] labels) { @@ -74,19 +117,22 @@ private Sample getRandomSubset(int numberOfFeatures, double[][] features, double } - public double predict(double[] feature) { + @Override + public double[] predict(double[] feature) { ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()); - List> futures = new ArrayList<>(); + List> futures = new ArrayList<>(); - for (IDecisionTree tree : trees) { + for (Classifier tree : trees) { futures.add(executor.submit(() -> tree.predict(feature))); } Map voteMap = new HashMap<>(); - for (Future future : futures) { + for (Future future : futures) { try { - double prediction = future.get(); - voteMap.merge(prediction, 1L, Long::sum); + double[] prediction = future.get(); + /* voteMap.merge(prediction, 1L, Long::sum);*/ + double label = getIndexOfHighestValue(prediction); + voteMap.merge(label, 1L, Long::sum); } catch (InterruptedException | ExecutionException e) { LOG.error("Failed to retrieve prediction from future task. 
Thread: " + Thread.currentThread().getName(), e); @@ -102,26 +148,33 @@ public double predict(double[] feature) { executor.shutdownNow(); Thread.currentThread().interrupt(); } - - return voteMap.entrySet().stream() + double predictionLabel = voteMap.entrySet().stream() .max(Map.Entry.comparingByValue()) - .map(Map.Entry::getKey) - .orElseThrow(() -> new RuntimeException("Failed to find the most common prediction")); - } + .get() + .getKey(); + double[] prediction = new double[trees.get(0).predict(feature).length]; + prediction[(int) predictionLabel] = 1; + return prediction; + } + @Override public double evaluate(double[][] features, double[][] labels) { int correctPredictions = 0; for (int i = 0; i < features.length; i++) { - double predictedLabel = predict(features[i]); + double[] predictedLabelProbabilities = predict(features[i]); + double predictedLabel = getIndexOfHighestValue(predictedLabelProbabilities); double actualLabel = getIndexOfHighestValue(labels[i]); if (predictedLabel == actualLabel) { correctPredictions++; } } - return (double) correctPredictions / features.length; + double accuracy = (double) correctPredictions / features.length; + LOG.info("RandomForest - Accuracy: " + String.format("%.4f", accuracy * 100) + "%"); + return accuracy; } + private double getIndexOfHighestValue(double[] labels) { int maxIndex = 0; double maxValue = labels[0]; diff --git a/lib/src/main/java/de/edux/ml/svm/SupportVectorMachine.java b/lib/src/main/java/de/edux/ml/svm/SupportVectorMachine.java index 3a23472..e4ef6fe 100644 --- a/lib/src/main/java/de/edux/ml/svm/SupportVectorMachine.java +++ b/lib/src/main/java/de/edux/ml/svm/SupportVectorMachine.java @@ -1,12 +1,13 @@ package de.edux.ml.svm; +import de.edux.api.Classifier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.*; import java.util.stream.Collectors; -public class SupportVectorMachine implements ISupportVectorMachine { +public class SupportVectorMachine implements Classifier { private 
static final Logger LOG = LoggerFactory.getLogger(SupportVectorMachine.class); @@ -20,9 +21,10 @@ public SupportVectorMachine(SVMKernel kernel, double c) { } @Override - public void train(double[][] features, int[] labels) { + public boolean train(double[][] features, double[][] labels) { + var oneDLabels = convert2DLabelArrayTo1DLabelArray(labels); // Identify unique class labels - Set uniqueLabels = Arrays.stream(labels).boxed().collect(Collectors.toSet()); + Set uniqueLabels = Arrays.stream(oneDLabels).boxed().collect(Collectors.toSet()); Integer[] uniqueLabelsArray = uniqueLabels.toArray(new Integer[0]); // In One-vs-One, you should consider every possible pair of classes @@ -35,10 +37,10 @@ public void train(double[][] features, int[] labels) { List list = new ArrayList<>(); List pairLabelsList = new ArrayList<>(); for (int k = 0; k < features.length; k++) { - if (labels[k] == uniqueLabelsArray[i] || labels[k] == uniqueLabelsArray[j]) { + if (oneDLabels[k] == uniqueLabelsArray[i] || oneDLabels[k] == uniqueLabelsArray[j]) { list.add(features[k]); // Ensure that the sign of the label matches our assumption - pairLabelsList.add(labels[k] == uniqueLabelsArray[i] ? 1 : -1); + pairLabelsList.add(oneDLabels[k] == uniqueLabelsArray[i] ? 
1 : -1); } } double[][] pairFeatures = list.toArray(new double[0][]); @@ -49,10 +51,11 @@ public void train(double[][] features, int[] labels) { models.put(key, model); } } + return true; } @Override - public int predict(double[] features) { + public double[] predict(double[] features) { Map voteCount = new HashMap<>(); // In One-vs-One, you look at the prediction of each model and count the votes for (Map.Entry entry : models.entrySet()) { @@ -66,24 +69,36 @@ public int predict(double[] features) { } // The final prediction is the class with the most votes - return voteCount.entrySet().stream().max(Map.Entry.comparingByValue()).get().getKey(); + int prediction = voteCount.entrySet().stream().max(Map.Entry.comparingByValue()).get().getKey(); + double[] result = new double[models.size()]; + result[prediction - 1] = 1; + return result; } - - - @Override - public double evaluate(double[][] features, int[] labels) { + public double evaluate(double[][] features, double[][] labels) { int correct = 0; for (int i = 0; i < features.length; i++) { - if (predict(features[i]) == labels[i]) { + boolean match = Arrays.equals(predict(features[i]), labels[i]); + if (match) { correct++; } } double accuracy = (double) correct / features.length; - LOG.info("Accuracy: " + String.format("%.4f", accuracy * 100) + "%"); + LOG.info("SVM - Accuracy: " + String.format("%.4f", accuracy * 100) + "%"); return accuracy; } + private int[] convert2DLabelArrayTo1DLabelArray(double[][] labels) { + int[] decisionTreeTrainLabels = new int[labels.length]; + for (int i = 0; i < labels.length; i++) { + for (int j = 0; j < labels[i].length; j++) { + if (labels[i][j] == 1) { + decisionTreeTrainLabels[i] = (j+1); + } + } + } + return decisionTreeTrainLabels; + } } diff --git a/lib/src/test/java/de/edux/data/provider/DataUtilTest.java b/lib/src/test/java/de/edux/data/provider/DataProcessorTest.java similarity index 58% rename from lib/src/test/java/de/edux/data/provider/DataUtilTest.java rename to 
lib/src/test/java/de/edux/data/provider/DataProcessorTest.java index df43465..f0c36f3 100644 --- a/lib/src/test/java/de/edux/data/provider/DataUtilTest.java +++ b/lib/src/test/java/de/edux/data/provider/DataProcessorTest.java @@ -1,6 +1,7 @@ package de.edux.data.provider; import de.edux.data.reader.CSVIDataReader; +import de.edux.ml.nn.network.api.Dataset; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -21,16 +22,16 @@ @ExtendWith(MockitoExtension.class) -class DataUtilTest { +class DataProcessorTest { @InjectMocks - private DataUtil dataUtil = getDummyDataUtil(); + private DataProcessor dataProcessor = getDummyDataUtil(); @Mock private CSVIDataReader csvDataReader; @BeforeEach void setUp() { - dataUtil = getDummyDataUtil(); + dataProcessor = getDummyDataUtil(); } @Test @@ -38,10 +39,10 @@ void testSplitWithValidRatio() { List dataset = Arrays.asList("A", "B", "C", "D", "E"); double trainTestSplitRatio = 0.6; - List> result = dataUtil.split(dataset, trainTestSplitRatio); + Dataset result = dataProcessor.split(dataset, trainTestSplitRatio); - assertEquals(3, result.get(0).size(), "Train dataset size should be 3"); - assertEquals(2, result.get(1).size(), "Test dataset size should be 2"); + assertEquals(3, result.trainData().size(), "Train dataset size should be 3"); + assertEquals(2, result.testData().size(), "Test dataset size should be 2"); } @Test @@ -49,10 +50,10 @@ void testSplitWithZeroRatio() { List dataset = Arrays.asList("A", "B", "C", "D", "E"); double trainTestSplitRatio = 0.0; - List> result = dataUtil.split(dataset, trainTestSplitRatio); + Dataset result = dataProcessor.split(dataset, trainTestSplitRatio); - assertEquals(0, result.get(0).size(), "Train dataset size should be 0"); - assertEquals(5, result.get(1).size(), "Test dataset size should be 5"); + assertEquals(0, result.trainData().size(), "Train dataset size should be 0"); + assertEquals(5, 
result.testData().size(), "Test dataset size should be 5"); } @Test @@ -60,10 +61,10 @@ void testSplitWithFullRatio() { List dataset = Arrays.asList("A", "B", "C", "D", "E"); double trainTestSplitRatio = 1.0; - List> result = dataUtil.split(dataset, trainTestSplitRatio); + Dataset result = dataProcessor.split(dataset, trainTestSplitRatio); - assertEquals(5, result.get(0).size(), "Train dataset size should be 5"); - assertEquals(0, result.get(1).size(), "Test dataset size should be 0"); + assertEquals(5, result.trainData().size(), "Train dataset size should be 5"); + assertEquals(0, result.testData().size(), "Test dataset size should be 0"); } @Test @@ -71,7 +72,7 @@ void testSplitWithInvalidNegativeRatio() { List dataset = Arrays.asList("A", "B", "C", "D", "E"); double trainTestSplitRatio = -0.1; - assertThrows(IllegalArgumentException.class, () -> dataUtil.split(dataset, trainTestSplitRatio)); + assertThrows(IllegalArgumentException.class, () -> dataProcessor.split(dataset, trainTestSplitRatio)); } @Test @@ -79,7 +80,7 @@ void testSplitWithInvalidAboveOneRatio() { List dataset = Arrays.asList("A", "B", "C", "D", "E"); double trainTestSplitRatio = 1.1; - assertThrows(IllegalArgumentException.class, () -> dataUtil.split(dataset, trainTestSplitRatio)); + assertThrows(IllegalArgumentException.class, () -> dataProcessor.split(dataset, trainTestSplitRatio)); } @Test @@ -94,13 +95,13 @@ void testLoadTDataSetWithoutNormalizationAndShuffling() { when(csvDataReader.readFile(any(), anyChar())).thenReturn(csvLine); - List result = dataUtil.loadTDataSet(dummyFile, separator, false, false, false); + List result = dataProcessor.loadDataSetFromCSV(dummyFile, separator, false, false, false); assertEquals(2, result.size(), "Dataset size should be 2"); } - private DataUtil getDummyDataUtil() { - return new DataUtil<>(csvDataReader) { + private DataProcessor getDummyDataUtil() { + return new DataProcessor<>(csvDataReader) { @Override public void normalize(List dataset) { @@ -121,6 
+122,31 @@ public double[][] getInputs(List dataset) { public double[][] getTargets(List dataset) { return new double[0][]; } + + @Override + public String getDatasetDescription() { + return null; + } + + @Override + public double[][] getTrainFeatures() { + return new double[0][]; + } + + @Override + public double[][] getTrainLabels() { + return new double[0][]; + } + + @Override + public double[][] getTestLabels() { + return new double[0][]; + } + + @Override + public double[][] getTestFeatures() { + return new double[0][]; + } }; } diff --git a/lib/src/test/java/de/edux/data/provider/SeabornDataProcessor.java b/lib/src/test/java/de/edux/data/provider/SeabornDataProcessor.java index b46999a..0584568 100644 --- a/lib/src/test/java/de/edux/data/provider/SeabornDataProcessor.java +++ b/lib/src/test/java/de/edux/data/provider/SeabornDataProcessor.java @@ -3,7 +3,7 @@ import java.util.ArrayList; import java.util.List; -public class SeabornDataProcessor extends DataUtil { +public class SeabornDataProcessor extends DataProcessor { @Override public void normalize(List penguins) { double maxBillLength = penguins.stream().mapToDouble(Penguin::billLengthMm).max().orElse(1); @@ -116,4 +116,67 @@ public double[][] getTargets(List dataset) { return targets; } + + @Override + public String getDatasetDescription() { + return "Seaborn penguins dataset"; + } + + @Override + public double[][] getTrainFeatures() { + return featuresOf(getSplitedDataset().trainData()); + } + + @Override + public double[][] getTrainLabels() { + return labelsOf(getSplitedDataset().trainData()); + } + + @Override + public double[][] getTestFeatures() { + return featuresOf(getSplitedDataset().testData()); + } + + @Override + public double[][] getTestLabels() { + return labelsOf(getSplitedDataset().testData()); + } + + private double[][] featuresOf(List data) { + double[][] features = new double[data.size()][4]; // 4 numerische Eigenschaften + + for (int i = 0; i < data.size(); i++) { + Penguin p = 
data.get(i); + features[i][0] = p.billLengthMm(); + features[i][1] = p.billDepthMm(); + features[i][2] = p.flipperLengthMm(); + features[i][3] = p.bodyMassG(); + } + + return features; + } + + private double[][] labelsOf(List data) { + double[][] labels = new double[data.size()][3]; // 3 Pinguinarten + + for (int i = 0; i < data.size(); i++) { + Penguin p = data.get(i); + switch (p.species().toLowerCase()) { + case "adelie": + labels[i] = new double[]{1.0, 0.0, 0.0}; + break; + case "chinstrap": + labels[i] = new double[]{0.0, 1.0, 0.0}; + break; + case "gentoo": + labels[i] = new double[]{0.0, 0.0, 1.0}; + break; + default: + throw new IllegalArgumentException("Unbekannte Pinguinart: " + p.species()); + } + } + + return labels; + } + } \ No newline at end of file diff --git a/lib/src/test/java/de/edux/ml/RandomForestTest.java b/lib/src/test/java/de/edux/ml/RandomForestTest.java index d59599c..a2a6c4c 100644 --- a/lib/src/test/java/de/edux/ml/RandomForestTest.java +++ b/lib/src/test/java/de/edux/ml/RandomForestTest.java @@ -1,5 +1,6 @@ package de.edux.ml; +import de.edux.api.Classifier; import de.edux.data.provider.Penguin; import de.edux.data.provider.SeabornDataProcessor; import de.edux.data.provider.SeabornProvider; @@ -28,9 +29,9 @@ static void setup() { } File csvFile = new File(url.getPath()); var seabornDataProcessor = new SeabornDataProcessor(); - var dataset = seabornDataProcessor.loadTDataSet(csvFile, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS); - List> trainTestSplittedList = seabornDataProcessor.split(dataset, TRAIN_TEST_SPLIT_RATIO); - seabornProvider = new SeabornProvider(dataset, trainTestSplittedList.get(0), trainTestSplittedList.get(1)); + var dataset = seabornDataProcessor.loadDataSetFromCSV(csvFile, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS); + var splitedDataset = seabornDataProcessor.split(dataset, TRAIN_TEST_SPLIT_RATIO); + seabornProvider = new SeabornProvider(dataset, splitedDataset.trainData(), splitedDataset.testData()); } 
@Test void train() { @@ -46,14 +47,14 @@ void train() { assertTrue(testLabels.length > 0); int numberOfTrees = 100; - int maxDepth = 24; + int maxDepth = 8; int minSampleSize = 2; int minSamplesLeaf = 1; - int maxLeafNodes = 12; + int maxLeafNodes = 2; int numFeatures = (int) Math.sqrt(features.length)*3; - RandomForest randomForest = new RandomForest(); - randomForest.train(numberOfTrees, features, labels, maxDepth, minSampleSize, minSamplesLeaf, maxLeafNodes,numFeatures); + Classifier randomForest = new RandomForest( numberOfTrees, maxDepth, minSampleSize, minSamplesLeaf, maxLeafNodes,numFeatures); + randomForest.train( features, labels); double accuracy = randomForest.evaluate(testFeatures, testLabels); System.out.println(accuracy); assertTrue(accuracy>0.7); diff --git a/lib/src/test/java/de/edux/ml/decisiontree/DecisionTreeTest.java b/lib/src/test/java/de/edux/ml/decisiontree/DecisionTreeTest.java index 99fcc4b..537f19c 100644 --- a/lib/src/test/java/de/edux/ml/decisiontree/DecisionTreeTest.java +++ b/lib/src/test/java/de/edux/ml/decisiontree/DecisionTreeTest.java @@ -3,6 +3,7 @@ import de.edux.data.provider.Penguin; import de.edux.data.provider.SeabornDataProcessor; import de.edux.data.provider.SeabornProvider; +import de.edux.ml.nn.network.api.Dataset; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.RepeatedTest; import org.junit.jupiter.api.Test; @@ -28,9 +29,9 @@ static void setup() { } File csvFile = new File(url.getPath()); var seabornDataProcessor = new SeabornDataProcessor(); - var dataset = seabornDataProcessor.loadTDataSet(csvFile, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS); - List> trainTestSplittedList = seabornDataProcessor.split(dataset, TRAIN_TEST_SPLIT_RATIO); - seabornProvider = new SeabornProvider(dataset, trainTestSplittedList.get(0), trainTestSplittedList.get(1)); + var dataset = seabornDataProcessor.loadDataSetFromCSV(csvFile, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS); + Dataset splitedDataset = 
seabornDataProcessor.split(dataset, TRAIN_TEST_SPLIT_RATIO); + seabornProvider = new SeabornProvider(dataset, splitedDataset.trainData(), splitedDataset.testData()); } @RepeatedTest(5) @@ -45,13 +46,13 @@ void train() { assertTrue(labels.length > 0); assertTrue(testFeatures.length > 0); assertTrue(testLabels.length > 0); - - IDecisionTree decisionTree = new DecisionTree(); - int maxDepth = 10; + int maxDepth = 12; int minSampleSplit = 2; int minSampleLeaf = 1; int maxLeafNodes = 8; - decisionTree.train(features, labels, maxDepth, minSampleSplit, minSampleLeaf, maxLeafNodes); + DecisionTree decisionTree = new DecisionTree(maxDepth, minSampleSplit, minSampleLeaf, maxLeafNodes); + + decisionTree.train(features, labels); double accuracy = decisionTree.evaluate(testFeatures, testLabels); assertTrue(accuracy>0.7); } diff --git a/lib/src/test/java/de/edux/ml/knn/KnnClassifierTest.java b/lib/src/test/java/de/edux/ml/knn/KnnClassifierTest.java deleted file mode 100644 index 56a9d29..0000000 --- a/lib/src/test/java/de/edux/ml/knn/KnnClassifierTest.java +++ /dev/null @@ -1,52 +0,0 @@ -package de.edux.ml.knn; - -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import java.util.Arrays; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -public class KnnClassifierTest { - - private KnnClassifier classifier; - private List trainingPoints; - private List testPoints; - - @BeforeEach - public void setup() { - // Erstellen Sie einige Trainingspunkte - ILabeledPoint trainingPoint1 = new KnnPoint(new double[]{1.0, 1.2, 1.4}, "Label1"); - ILabeledPoint trainingPoint2 = new KnnPoint(new double[]{3.0, 3.1, 3.2}, "Label2"); - ILabeledPoint trainingPoint3 = new KnnPoint(new double[]{7.0, 7.0, 7.0}, "Label3"); - - trainingPoints = Arrays.asList(trainingPoint1, trainingPoint2, trainingPoint3); - - // Erstellen Sie einige Testpunkte - ILabeledPoint testPoint1 = new KnnPoint(new double[]{1.0, 1.0, 1.5}, "Label1"); - 
ILabeledPoint testPoint2 = new KnnPoint(new double[]{3.0, 3.5, 3.1}, "Label2"); - ILabeledPoint testPoint3 = new KnnPoint(new double[]{7.0, 7.0, 7.0}, "Label3"); - - testPoints = Arrays.asList(testPoint1, testPoint2, testPoint3); - - // Erstellen Sie den Klassifikator - classifier = new KnnClassifier(1, trainingPoints); - - classifier.evaluate(testPoints); - } - - @Test - public void testClassify() { - // Testen Sie die classify() Methode - assertEquals("Label1", classifier.classify(testPoints.get(0))); - assertEquals("Label2", classifier.classify(testPoints.get(1))); - assertEquals("Label3", classifier.classify(testPoints.get(2))); - } - - @Test - public void testEvaluate() { - // Testen Sie die evaluate() Methode - assertEquals(100, classifier.evaluate(testPoints), 0.01); - } -} diff --git a/lib/src/test/java/de/edux/ml/nn/network/MultilayerPerceptronTest.java b/lib/src/test/java/de/edux/ml/nn/network/MultilayerPerceptronTest.java index 94d8075..9d32ff6 100644 --- a/lib/src/test/java/de/edux/ml/nn/network/MultilayerPerceptronTest.java +++ b/lib/src/test/java/de/edux/ml/nn/network/MultilayerPerceptronTest.java @@ -33,9 +33,9 @@ void setUp() { } File csvFile = new File(url.getPath()); var seabornDataProcessor = new SeabornDataProcessor(); - var dataset = seabornDataProcessor.loadTDataSet(csvFile, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS); - List> trainTestSplittedList = seabornDataProcessor.split(dataset, TRAIN_TEST_SPLIT_RATIO); - seabornProvider = new SeabornProvider(dataset, trainTestSplittedList.get(0), trainTestSplittedList.get(1)); + var dataset = seabornDataProcessor.loadDataSetFromCSV(csvFile, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS); + var trainTestSplittedList = seabornDataProcessor.split(dataset, TRAIN_TEST_SPLIT_RATIO); + seabornProvider = new SeabornProvider(dataset, trainTestSplittedList.trainData(), trainTestSplittedList.testData()); } @@ -52,11 +52,11 @@ void shouldReachModelAccuracyAtLeast70() { assertTrue(testFeatures.length > 
0); assertTrue(testLabels.length > 0); - NetworkConfiguration networkConfiguration = new NetworkConfiguration(features[0].length, List.of(24, 6), 3, 0.001, 10000, ActivationFunction.LEAKY_RELU, ActivationFunction.SOFTMAX, LossFunction.CATEGORICAL_CROSS_ENTROPY, Initialization.XAVIER, Initialization.XAVIER); + NetworkConfiguration networkConfiguration = new NetworkConfiguration(features[0].length, List.of(128,256, 512), 3, 0.01, 300, ActivationFunction.LEAKY_RELU, ActivationFunction.SOFTMAX, LossFunction.CATEGORICAL_CROSS_ENTROPY, Initialization.XAVIER, Initialization.XAVIER); - MultilayerPerceptron multilayerPerceptron = new MultilayerPerceptron(features, labels, testFeatures, testLabels, networkConfiguration); - multilayerPerceptron.train(); - double accuracy = multilayerPerceptron.getAccuracy(); + MultilayerPerceptron multilayerPerceptron = new MultilayerPerceptron(networkConfiguration, testFeatures, testLabels); + multilayerPerceptron.train(features, labels); + double accuracy = multilayerPerceptron.evaluate(testFeatures, testLabels); assertTrue(accuracy > 0.7); }