diff --git a/example/src/main/java/de/example/benchmark/Benchmark.java b/example/src/main/java/de/example/benchmark/Benchmark.java
index 2f3fcc7..fb93600 100644
--- a/example/src/main/java/de/example/benchmark/Benchmark.java
+++ b/example/src/main/java/de/example/benchmark/Benchmark.java
@@ -90,7 +90,6 @@ private void run() {
         });
 
-        //Sort and print results with numeration begin with best average accuracy
         System.out.println("Classifier performances (sorted by average accuracy):");
         results.entrySet().stream()
                 .map(entry -> {
diff --git a/example/src/main/java/de/example/data/seaborn/SeabornDataProcessor.java b/example/src/main/java/de/example/data/seaborn/SeabornDataProcessor.java
index 124b6f3..376767e 100644
--- a/example/src/main/java/de/example/data/seaborn/SeabornDataProcessor.java
+++ b/example/src/main/java/de/example/data/seaborn/SeabornDataProcessor.java
@@ -145,7 +145,7 @@ public double[][] getTestLabels() {
     }
 
     private double[][] featuresOf(List<Penguin> data) {
-        double[][] features = new double[data.size()][4]; // 4 numerische Eigenschaften
+        double[][] features = new double[data.size()][4];
 
         for (int i = 0; i < data.size(); i++) {
             Penguin p = data.get(i);
@@ -159,7 +159,7 @@ private double[][] featuresOf(List<Penguin> data) {
     }
 
     private double[][] labelsOf(List<Penguin> data) {
-        double[][] labels = new double[data.size()][3]; // 3 Pinguinarten
+        double[][] labels = new double[data.size()][3];
 
         for (int i = 0; i < data.size(); i++) {
             Penguin p = data.get(i);
diff --git a/example/src/main/java/de/example/data/seaborn/SeabornProvider.java b/example/src/main/java/de/example/data/seaborn/SeabornProvider.java
index 3c370d6..e2e5416 100644
--- a/example/src/main/java/de/example/data/seaborn/SeabornProvider.java
+++ b/example/src/main/java/de/example/data/seaborn/SeabornProvider.java
@@ -64,7 +64,7 @@ public double[][] getTestLabels() {
     }
 
     private double[][] featuresOf(List<Penguin> data) {
-        double[][] features = new double[data.size()][4]; // 4 numerische Eigenschaften
+        double[][] features = new double[data.size()][4];
 
         for (int i = 0; i < data.size(); i++) {
             Penguin p = data.get(i);
@@ -77,7 +77,7 @@ private double[][] featuresOf(List<Penguin> data) {
         return features;
     }
 
     private double[][] labelsOf(List<Penguin> data) {
-        double[][] labels = new double[data.size()][3]; // 3 Pinguinarten
+        double[][] labels = new double[data.size()][3];
 
         for (int i = 0; i < data.size(); i++) {
             Penguin p = data.get(i);
diff --git a/example/src/main/java/de/example/decisiontree/DecisionTreeExample.java b/example/src/main/java/de/example/decisiontree/DecisionTreeExample.java
index 3c856cd..f51e4fa 100644
--- a/example/src/main/java/de/example/decisiontree/DecisionTreeExample.java
+++ b/example/src/main/java/de/example/decisiontree/DecisionTreeExample.java
@@ -10,19 +10,16 @@ public class DecisionTreeExample {
     private static final boolean NORMALIZE = true;
 
     public static void main(String[] args) {
-        // Get IRIS dataset
+
         var datasetProvider = new IrisProvider(NORMALIZE, SHUFFLE, 0.6);
         datasetProvider.printStatistics();
 
-        //Get Features and Labels
        double[][] features = datasetProvider.getTrainFeatures();
        double[][] labels = datasetProvider.getTrainLabels();
 
-        // Train Decision Tree
        DecisionTree decisionTree = new DecisionTree(8, 2, 1, 4);
        decisionTree.train(features, labels);
 
-        // Evaluate Decision Tree
        double[][] testFeatures = datasetProvider.getTestFeatures();
        double[][] testLabels = datasetProvider.getTestLabels();
        decisionTree.evaluate(testFeatures, testLabels);
diff --git a/example/src/main/java/de/example/knn/KnnIrisExample.java b/example/src/main/java/de/example/knn/KnnIrisExample.java
index 0e5d536..e4d1850 100644
--- a/example/src/main/java/de/example/knn/KnnIrisExample.java
+++ b/example/src/main/java/de/example/knn/KnnIrisExample.java
@@ -27,7 +27,7 @@ public static void main(String[] args) {
         irisDataProcessor.split(data, TRAIN_TEST_SPLIT_RATIO);
 
         Classifier knn = new KnnClassifier(2);
-        //Train and evaluate
+
         knn.train(irisDataProcessor.getTrainFeatures(), irisDataProcessor.getTrainLabels());
         knn.evaluate(irisDataProcessor.getTestFeatures(), irisDataProcessor.getTestLabels());
     }
diff --git a/example/src/main/java/de/example/knn/KnnSeabornExample.java b/example/src/main/java/de/example/knn/KnnSeabornExample.java
index faf4f90..324d438 100644
--- a/example/src/main/java/de/example/knn/KnnSeabornExample.java
+++ b/example/src/main/java/de/example/knn/KnnSeabornExample.java
@@ -22,15 +22,14 @@ public class KnnSeabornExample {
     private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "seaborn-penguins" + File.separator + "penguins.csv");
 
     public static void main(String[] args) {
-        //Load Data, shuffle, normalize, filter incomplete records out.
         var seabornDataProcessor = new SeabornDataProcessor();
         List<Penguin> data = seabornDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS);
-        //Split dataset into train and test
+
         Dataset dataset = seabornDataProcessor.split(data, TRAIN_TEST_SPLIT_RATIO);
         var seabornProvider = new SeabornProvider(data, dataset.trainData(), dataset.testData());
         seabornProvider.printStatistics();
+
         Classifier knn = new KnnClassifier(2);
-        //Train and evaluate
         knn.train(seabornProvider.getTrainFeatures(), seabornProvider.getTrainLabels());
         knn.evaluate(seabornProvider.getTestFeatures(), seabornProvider.getTestLabels());
     }
diff --git a/example/src/main/java/de/example/nn/MultilayerPerceptronExample.java b/example/src/main/java/de/example/nn/MultilayerPerceptronExample.java
index 4e90069..93655f1 100644
--- a/example/src/main/java/de/example/nn/MultilayerPerceptronExample.java
+++ b/example/src/main/java/de/example/nn/MultilayerPerceptronExample.java
@@ -15,15 +15,12 @@ public class MultilayerPerceptronExample {
     private final static boolean NORMALIZE = true;
 
     public static void main(String[] args) {
-        // Get IRIS dataset
         var datasetProvider = new IrisProvider(NORMALIZE, SHUFFLE, 0.7);
         datasetProvider.printStatistics();
 
-        //Get Features and Labels
         double[][] features = datasetProvider.getTrainFeatures();
         double[][] labels = datasetProvider.getTrainLabels();
 
-        //Get Test Features and Labels
         double[][] testFeatures = datasetProvider.getTestFeatures();
         double[][] testLabels = datasetProvider.getTestLabels();
diff --git a/example/src/main/java/de/example/svm/SVMExample.java b/example/src/main/java/de/example/svm/SVMExample.java
index f245fcd..7ef1984 100644
--- a/example/src/main/java/de/example/svm/SVMExample.java
+++ b/example/src/main/java/de/example/svm/SVMExample.java
@@ -14,14 +14,11 @@ public static void main(String[] args){
         var datasetProvider = new IrisProvider(NORMALIZE, SHUFFLE, 0.6);
         datasetProvider.printStatistics();
 
-        //Get Features and Labels
         var features = datasetProvider.getTrainFeatures();
-        // 1 - SATOSA 2 - VERSICOLOR 3 - VIRGINICA
         var labels = datasetProvider.getTrainLabels();
-
         Classifier supportVectorMachine = new SupportVectorMachine(SVMKernel.LINEAR, 1);
-        //ONEvsONE Strategy
+
         supportVectorMachine.train(features, labels);
 
         double[][] testFeatures = datasetProvider.getTestFeatures();
diff --git a/lib/src/main/java/de/edux/data/reader/CSVIDataReader.java b/lib/src/main/java/de/edux/data/reader/CSVIDataReader.java
index 3e835c3..6b14067 100644
--- a/lib/src/main/java/de/edux/data/reader/CSVIDataReader.java
+++ b/lib/src/main/java/de/edux/data/reader/CSVIDataReader.java
@@ -14,12 +14,12 @@ public class CSVIDataReader implements IDataReader {
 
     public List<String[]> readFile(File file, char separator) {
-        CSVParser csvParser = new CSVParserBuilder().withSeparator(separator).build(); // custom separator
+        CSVParser customCSVParser = new CSVParserBuilder().withSeparator(separator).build();
         List<String[]> result;
         try(CSVReader reader = new CSVReaderBuilder(
                 new FileReader(file))
-                .withCSVParser(csvParser) // custom CSV parser
-                .withSkipLines(1) // skip the first line, header info
+                .withCSVParser(customCSVParser)
+                .withSkipLines(1)
                 .build()){
             result = reader.readAll();
         } catch (CsvException | IOException e) {
diff --git a/lib/src/main/java/de/edux/functions/initialization/Initialization.java b/lib/src/main/java/de/edux/functions/initialization/Initialization.java
index 91627e4..a174107 100644
--- a/lib/src/main/java/de/edux/functions/initialization/Initialization.java
+++ b/lib/src/main/java/de/edux/functions/initialization/Initialization.java
@@ -1,7 +1,6 @@
 package de.edux.functions.initialization;
 
 public enum Initialization {
-    //Xavier and HE
     XAVIER {
         @Override
         public double[] weightInitialization(int inputSize, double[] weights) {
diff --git a/lib/src/main/java/de/edux/ml/nn/network/MultilayerPerceptron.java b/lib/src/main/java/de/edux/ml/nn/network/MultilayerPerceptron.java
index 626842c..3de021b 100644
--- a/lib/src/main/java/de/edux/ml/nn/network/MultilayerPerceptron.java
+++ b/lib/src/main/java/de/edux/ml/nn/network/MultilayerPerceptron.java
@@ -88,8 +88,16 @@ public MultilayerPerceptron(NetworkConfiguration config, double[][] testFeatures
     }
 
     private double[] feedforward(double[] input) {
-        double[] currentInput = input;
+        double[] currentInput = passInputTroughAllHiddenLayers(input);
+
+        double[] output = passInputTroughOutputLayer(currentInput);
+
+        return outputLayerActivationFunction.calculateActivation(output);
+    }
+
+    private double[] passInputTroughAllHiddenLayers(double[] input) {
+        double[] currentInput = input;
         for (Neuron[] layer : hiddenLayers) {
             double[] hiddenOutputs = new double[layer.length];
             for (int i = 0; i < layer.length; i++) {
@@ -97,13 +105,15 @@ private double[] feedforward(double[] input) {
             }
             currentInput = hiddenOutputs;
         }
+        return currentInput;
+    }
 
+    private double[] passInputTroughOutputLayer(double[] currentInput) {
         double[] output = new double[config.outputSize()];
         for (int i = 0; i < config.outputSize(); i++) {
             output[i] = outputLayer[i].calculateOutput(currentInput);
         }
-
-        return outputLayerActivationFunction.calculateActivation(output);
+        return output;
     }
 
     @Override
diff --git a/lib/src/main/java/de/edux/ml/randomforest/RandomForest.java b/lib/src/main/java/de/edux/ml/randomforest/RandomForest.java
index bf1e55f..e4c29b2 100644
--- a/lib/src/main/java/de/edux/ml/randomforest/RandomForest.java
+++ b/lib/src/main/java/de/edux/ml/randomforest/RandomForest.java
@@ -129,7 +129,6 @@ public double[] predict(double[] feature) {
         for (Future<double[]> future : futures) {
             try {
                 double[] prediction = future.get();
-                /* voteMap.merge(prediction, 1L, Long::sum);*/
                 double label = getIndexOfHighestValue(prediction);
                 voteMap.merge(label, 1L, Long::sum);
             } catch (InterruptedException | ExecutionException e) {
diff --git a/lib/src/main/java/de/edux/ml/svm/ISupportVectorMachine.java b/lib/src/main/java/de/edux/ml/svm/ISupportVectorMachine.java
index 13e39cb..ef4b1da 100644
--- a/lib/src/main/java/de/edux/ml/svm/ISupportVectorMachine.java
+++ b/lib/src/main/java/de/edux/ml/svm/ISupportVectorMachine.java
@@ -4,9 +4,7 @@ public interface ISupportVectorMachine {
 
     void train(double[][] features, int[] labels);
 
-    // Methode zum Klassifizieren eines einzelnen Datenpunkts
     int predict(double[] features);
 
-    // Methode zum Evaluieren der Leistung der SVM auf einem Testdatensatz
     double evaluate(double[][] features, int[] labels);
 }
diff --git a/lib/src/main/java/de/edux/ml/svm/SupportVectorMachine.java b/lib/src/main/java/de/edux/ml/svm/SupportVectorMachine.java
index 1f6f04d..d6510d5 100644
--- a/lib/src/main/java/de/edux/ml/svm/SupportVectorMachine.java
+++ b/lib/src/main/java/de/edux/ml/svm/SupportVectorMachine.java
@@ -42,30 +42,25 @@ public SupportVectorMachine(SVMKernel kernel, double c) {
     @Override
     public boolean train(double[][] features, double[][] labels) {
         var oneDLabels = convert2DLabelArrayTo1DLabelArray(labels);
-        // Identify unique class labels
         Set<Integer> uniqueLabels = Arrays.stream(oneDLabels).boxed().collect(Collectors.toSet());
         Integer[] uniqueLabelsArray = uniqueLabels.toArray(new Integer[0]);
 
-        // In One-vs-One, you should consider every possible pair of classes
         for (int i = 0; i < uniqueLabelsArray.length; i++) {
             for (int j = i + 1; j < uniqueLabelsArray.length; j++) {
                 String key = uniqueLabelsArray[i] + "-" + uniqueLabelsArray[j];
                 SVMModel model = new SVMModel(kernel, c);
 
-                // Filter the features and labels for the two classes
                 List<double[]> list = new ArrayList<>();
                 List<Integer> pairLabelsList = new ArrayList<>();
                 for (int k = 0; k < features.length; k++) {
                     if (oneDLabels[k] == uniqueLabelsArray[i] || oneDLabels[k] == uniqueLabelsArray[j]) {
                         list.add(features[k]);
-                        // Ensure that the sign of the label matches our assumption
                         pairLabelsList.add(oneDLabels[k] == uniqueLabelsArray[i] ? 1 : -1);
                     }
                 }
                 double[][] pairFeatures = list.toArray(new double[0][]);
                 int[] pairLabels = pairLabelsList.stream().mapToInt(Integer::intValue).toArray();
 
-                // Train the model on the pair
                 model.train(pairFeatures, pairLabels);
                 models.put(key, model);
             }
@@ -76,18 +71,16 @@ public boolean train(double[][] features, double[][] labels) {
     @Override
     public double[] predict(double[] features) {
         Map<Integer, Integer> voteCount = new HashMap<>();
-        // In One-vs-One, you look at the prediction of each model and count the votes
+
         for (Map.Entry<String, SVMModel> entry : models.entrySet()) {
             int prediction = entry.getValue().predict(features);
-            // map prediction back to actual class label
             String[] classes = entry.getKey().split("-");
            int classLabel = (prediction == 1) ? Integer.parseInt(classes[0]) : Integer.parseInt(classes[1]);
             voteCount.put(classLabel, voteCount.getOrDefault(classLabel, 0) + 1);
         }
-        // The final prediction is the class with the most votes
         int prediction = voteCount.entrySet().stream().max(Map.Entry.comparingByValue()).get().getKey();
 
         double[] result = new double[models.size()];
         result[prediction - 1] = 1;
diff --git a/lib/src/test/java/de/edux/data/provider/DataProcessorTest.java b/lib/src/test/java/de/edux/data/provider/DataProcessorTest.java
index f0c36f3..a0b3af0 100644
--- a/lib/src/test/java/de/edux/data/provider/DataProcessorTest.java
+++ b/lib/src/test/java/de/edux/data/provider/DataProcessorTest.java
@@ -105,7 +105,6 @@ private DataProcessor getDummyDataUtil() {
 
             @Override
             public void normalize(List dataset) {
-                // Mock normalize for the sake of testing
             }
 
             @Override
diff --git a/lib/src/test/java/de/edux/data/provider/SeabornDataProcessor.java b/lib/src/test/java/de/edux/data/provider/SeabornDataProcessor.java
index 0584568..bc08ee8 100644
--- a/lib/src/test/java/de/edux/data/provider/SeabornDataProcessor.java
+++ b/lib/src/test/java/de/edux/data/provider/SeabornDataProcessor.java
@@ -143,7 +143,7 @@ public double[][] getTestLabels() {
     }
 
     private double[][] featuresOf(List<Penguin> data) {
-        double[][] features = new double[data.size()][4]; // 4 numerische Eigenschaften
+        double[][] features = new double[data.size()][4];
 
         for (int i = 0; i < data.size(); i++) {
             Penguin p = data.get(i);
@@ -157,7 +157,7 @@ private double[][] featuresOf(List<Penguin> data) {
     }
 
     private double[][] labelsOf(List<Penguin> data) {
-        double[][] labels = new double[data.size()][3]; // 3 Pinguinarten
+        double[][] labels = new double[data.size()][3];
 
         for (int i = 0; i < data.size(); i++) {
             Penguin p = data.get(i);
diff --git a/lib/src/test/java/de/edux/data/provider/SeabornProvider.java b/lib/src/test/java/de/edux/data/provider/SeabornProvider.java
index ead4ccc..de98c57 100644
--- a/lib/src/test/java/de/edux/data/provider/SeabornProvider.java
+++ b/lib/src/test/java/de/edux/data/provider/SeabornProvider.java
@@ -50,7 +50,7 @@ public double[][] getTrainFeatures() {
     }
 
     private double[][] featuresOf(List<Penguin> data) {
-        double[][] features = new double[data.size()][4]; // 4 numerische Eigenschaften
+        double[][] features = new double[data.size()][4];
 
         for (int i = 0; i < data.size(); i++) {
             Penguin p = data.get(i);
@@ -70,7 +70,7 @@ public double[][] getTrainLabels() {
     }
 
     private double[][] labelsOf(List<Penguin> data) {
-        double[][] labels = new double[data.size()][3]; // 3 Pinguinarten
+        double[][] labels = new double[data.size()][3];
 
         for (int i = 0; i < data.size(); i++) {
             Penguin p = data.get(i);
diff --git a/lib/src/test/java/de/edux/edux/activation/ActivationFunctionTest.java b/lib/src/test/java/de/edux/edux/activation/ActivationFunctionTest.java
index 6bbcbc5..deb45b3 100644
--- a/lib/src/test/java/de/edux/edux/activation/ActivationFunctionTest.java
+++ b/lib/src/test/java/de/edux/edux/activation/ActivationFunctionTest.java
@@ -8,7 +8,7 @@
 public class ActivationFunctionTest {
 
-    private static final double DELTA = 1e-6; // used to compare floating point numbers
+    private static final double DELTA = 1e-6;
 
     @Test
     public void testSigmoid() {