diff --git a/example/src/main/java/de/example/benchmark/Benchmark.java b/example/src/main/java/de/example/benchmark/Benchmark.java
index fb93600..8ee0330 100644
--- a/example/src/main/java/de/example/benchmark/Benchmark.java
+++ b/example/src/main/java/de/example/benchmark/Benchmark.java
@@ -1,6 +1,8 @@
 package de.example.benchmark;
 import de.edux.api.Classifier;
+import de.edux.data.provider.DataProcessor;
+import de.edux.data.reader.CSVIDataReader;
 import de.edux.functions.activation.ActivationFunction;
 import de.edux.functions.initialization.Initialization;
 import de.edux.functions.loss.LossFunction;
@@ -11,8 +13,6 @@ import de.edux.ml.randomforest.RandomForest;
 import de.edux.ml.svm.SVMKernel;
 import de.edux.ml.svm.SupportVectorMachine;
-import de.example.data.seaborn.Penguin;
-import de.example.data.seaborn.SeabornDataProcessor;
 import java.io.File;
 import java.util.ArrayList;
@@ -25,11 +25,10 @@
  * Compare the performance of different classifiers
  */
 public class Benchmark {
-    private static final boolean SHUFFLE = true;
-    private static final boolean NORMALIZE = true;
-    private static final boolean FILTER_INCOMPLETE_RECORDS = true;
-    private static final double TRAIN_TEST_SPLIT_RATIO = 0.75;
-    private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "seaborn-penguins" + File.separator + "penguins.csv");
+    private static final double TRAIN_TEST_SPLIT_RATIO = 0.70;
+    private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "iris" + File.separator + "iris.csv");
+    private static final boolean SKIP_HEAD = true;
+
     private double[][] trainFeatures;
     private double[][] trainLabels;
     private double[][] testFeatures;
@@ -49,7 +48,7 @@ private void run() {
         Classifier randomForest = new RandomForest(100, 10, 2, 1, 3, 60);
         Classifier svm = new SupportVectorMachine(SVMKernel.LINEAR, 1);
-        networkConfiguration = new NetworkConfiguration(trainFeatures[0].length, List.of(128,256, 512), 3, 0.01, 300, ActivationFunction.LEAKY_RELU, ActivationFunction.SOFTMAX, LossFunction.CATEGORICAL_CROSS_ENTROPY, Initialization.XAVIER, Initialization.XAVIER);
+        networkConfiguration = new NetworkConfiguration(trainFeatures[0].length, List.of(128, 256, 512), 3, 0.01, 300, ActivationFunction.LEAKY_RELU, ActivationFunction.SOFTMAX, LossFunction.CATEGORICAL_CROSS_ENTROPY, Initialization.XAVIER, Initialization.XAVIER);
         multilayerPerceptron = new MultilayerPerceptron(networkConfiguration, testFeatures, testLabels);
         Map<String, Classifier> classifiers = Map.of(
                 "KNN", knn,
@@ -67,7 +66,7 @@ private void run() {
         results.put("MLP", new ArrayList<>());
-        IntStream.range(0, 50).forEach(i -> {
+        IntStream.range(0, 1).forEach(i -> {
             knn.train(trainFeatures, trainLabels);
             decisionTree.train(trainFeatures, trainLabels);
             randomForest.train(trainFeatures, trainLabels);
@@ -126,16 +125,21 @@ private void updateMLP(double[][] testFeatures, double[][] testLabels) {
     }
 
     private void initFeaturesAndLabels() {
-        var seabornDataProcessor = new SeabornDataProcessor();
-        List<Penguin> data = seabornDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS);
-        seabornDataProcessor.split(data, TRAIN_TEST_SPLIT_RATIO);
+        var featureColumnIndices = new int[]{0, 1, 2, 3};
+        var targetColumnIndex = 4;
+
+        var dataProcessor = new DataProcessor(new CSVIDataReader())
+                .loadDataSetFromCSV(CSV_FILE, ',', SKIP_HEAD, featureColumnIndices, targetColumnIndex)
+                .normalize()
+                .shuffle()
+                .split(TRAIN_TEST_SPLIT_RATIO);
 
-        trainFeatures = seabornDataProcessor.getTrainFeatures();
-        trainLabels = seabornDataProcessor.getTrainLabels();
-        testFeatures = seabornDataProcessor.getTestFeatures();
-        testLabels = seabornDataProcessor.getTestLabels();
+        trainFeatures = dataProcessor.getTrainFeatures(featureColumnIndices);
+        trainLabels = dataProcessor.getTrainLabels(targetColumnIndex);
+        testFeatures = dataProcessor.getTestFeatures(featureColumnIndices);
+        testLabels = dataProcessor.getTestLabels(targetColumnIndex);
    }
 }
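For reference, the new fluent loading pipeline introduced by this hunk, as a self-contained sketch. The class name FluentLoadingSketch and the main method are illustrative only; the path, column indices, and API calls are taken from the diff above:

import de.edux.data.provider.DataProcessor;
import de.edux.data.reader.CSVIDataReader;
import java.io.File;

public class FluentLoadingSketch {
    public static void main(String[] args) {
        var featureColumnIndices = new int[]{0, 1, 2, 3};
        var targetColumnIndex = 4;
        var csvFile = new File("example/datasets/iris/iris.csv");

        // load -> normalize -> shuffle -> split all return the same processor,
        // so the train/test accessors can be called on the chained result
        var processor = new DataProcessor(new CSVIDataReader())
                .loadDataSetFromCSV(csvFile, ',', true, featureColumnIndices, targetColumnIndex)
                .normalize()
                .shuffle()
                .split(0.70);

        double[][] trainFeatures = processor.getTrainFeatures(featureColumnIndices);
        double[][] trainLabels = processor.getTrainLabels(targetColumnIndex);
    }
}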
diff --git a/example/src/main/java/de/example/data/iris/Iris.java b/example/src/main/java/de/example/data/iris/Iris.java
deleted file mode 100644
index a56feb1..0000000
--- a/example/src/main/java/de/example/data/iris/Iris.java
+++ /dev/null
@@ -1,33 +0,0 @@
-package de.example.data.iris;
-
-
-public class Iris{
-    public double sepalLength;
-    public double sepalWidth;
-    public double petalLength;
-    public double petalWidth;
-    public String variety;
-
-    public Iris(double sepalLength, double sepalWidth, double petalLength, double petalWidth, String variety) {
-        this.sepalLength = sepalLength;
-        this.sepalWidth = sepalWidth;
-        this.petalLength = petalLength;
-        this.petalWidth = petalWidth;
-        this.variety = variety;
-    }
-
-    @Override
-    public String toString() {
-        return "{sepalLength=" + sepalLength +
-                ", sepalWidth=" + sepalWidth +
-                ", petalLength=" + petalLength +
-                ", petalWidth=" + petalWidth +
-                ", variety='" + variety + '\'' +
-                '}';
-    }
-
-    public double[] getFeatures() {
-        return new double[]{sepalLength, sepalWidth, petalLength, petalWidth};
-    }
-
-}
diff --git a/example/src/main/java/de/example/data/iris/IrisDataProcessor.java b/example/src/main/java/de/example/data/iris/IrisDataProcessor.java
deleted file mode 100644
index bbe9c12..0000000
--- a/example/src/main/java/de/example/data/iris/IrisDataProcessor.java
+++ /dev/null
@@ -1,133 +0,0 @@
-package de.example.data.iris;
-
-import de.edux.data.provider.DataProcessor;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.util.List;
-
-public class IrisDataProcessor extends DataProcessor<Iris> {
-    private static final Logger LOG = LoggerFactory.getLogger(IrisDataProcessor.class);
-    private double[][] targets;
-
-    @Override
-    public void normalize(List<Iris> rowDataset) {
-        double minSepalLength = rowDataset.stream().mapToDouble(iris -> iris.sepalLength).min().getAsDouble();
-        double maxSepalLength = rowDataset.stream().mapToDouble(iris -> iris.sepalLength).max().getAsDouble();
-        double minSepalWidth = rowDataset.stream().mapToDouble(iris -> iris.sepalWidth).min().getAsDouble();
-        double maxSepalWidth = rowDataset.stream().mapToDouble(iris -> iris.sepalWidth).max().getAsDouble();
-        double minPetalLength = rowDataset.stream().mapToDouble(iris -> iris.petalLength).min().getAsDouble();
-        double maxPetalLength = rowDataset.stream().mapToDouble(iris -> iris.petalLength).max().getAsDouble();
-        double minPetalWidth = rowDataset.stream().mapToDouble(iris -> iris.petalWidth).min().getAsDouble();
-        double maxPetalWidth = rowDataset.stream().mapToDouble(iris -> iris.petalWidth).max().getAsDouble();
-
-        for (Iris iris : rowDataset) {
-            iris.sepalLength = (iris.sepalLength - minSepalLength) / (maxSepalLength - minSepalLength);
-            iris.sepalWidth = (iris.sepalWidth - minSepalWidth) / (maxSepalWidth - minSepalWidth);
-            iris.petalLength = (iris.petalLength - minPetalLength) / (maxPetalLength - minPetalLength);
-            iris.petalWidth = (iris.petalWidth - minPetalWidth) / (maxPetalWidth - minPetalWidth);
-        }
-    }
-
-    @Override
-    public Iris mapToDataRecord(String[] csvLine) {
- return new Iris( - Double.parseDouble(csvLine[0]), - Double.parseDouble(csvLine[1]), - Double.parseDouble(csvLine[2]), - Double.parseDouble(csvLine[3]), - csvLine[4] - ); - } - - @Override - public double[][] getInputs(List dataset) { - double[][] inputs = new double[dataset.size()][4]; - for (int i = 0; i < dataset.size(); i++) { - inputs[i][0] = dataset.get(i).sepalLength; - inputs[i][1] = dataset.get(i).sepalWidth; - inputs[i][2] = dataset.get(i).petalLength; - inputs[i][3] = dataset.get(i).petalWidth; - } - return inputs; - } - - @Override - public double[][] getTargets(List dataset) { - targets = new double[dataset.size()][3]; - for (int i = 0; i < dataset.size(); i++) { - switch (dataset.get(i).variety) { - case "Setosa": - targets[i][0] = 1; - break; - case "Versicolor": - targets[i][1] = 1; - break; - case "Virginica": - targets[i][2] = 1; - break; - } - } - return targets; - } - - @Override - public double[][] getTrainFeatures() { - return featuresOf(getSplitedDataset().trainData()); - } - - @Override - public double[][] getTrainLabels() { - return labelsOf(getSplitedDataset().trainData()); - } - - @Override - public double[][] getTestLabels() { - return labelsOf(getSplitedDataset().testData()); - } - - @Override - public double[][] getTestFeatures() { - return featuresOf(getSplitedDataset().testData()); - } - - private double[][] featuresOf(List testData) { - double[][] features = new double[testData.size()][4]; - for (int i = 0; i < testData.size(); i++) { - features[i][0] = testData.get(i).sepalLength; - features[i][1] = testData.get(i).sepalWidth; - features[i][2] = testData.get(i).petalLength; - features[i][3] = testData.get(i).petalWidth; - } - return features; - } - - private double[][] labelsOf(List data) { - double[][] labels = new double[data.size()][3]; - for (int i = 0; i < data.size(); i++) { - if (data.get(i).variety.equals("Setosa")) { - labels[i][0] = 1; - labels[i][1] = 0; - labels[i][2] = 0; - } - if (data.get(i).variety.equals("Versicolor")) { - labels[i][0] = 0; - labels[i][1] = 1; - labels[i][2] = 0; - } - if (data.get(i).variety.equals("Virginica")) { - labels[i][0] = 0; - labels[i][1] = 0; - labels[i][2] = 1; - } - } - return labels; - } - - @Override - public String getDatasetDescription() { - return "Iris dataset"; - } - - -} diff --git a/example/src/main/java/de/example/data/iris/IrisDataUtil.java b/example/src/main/java/de/example/data/iris/IrisDataUtil.java deleted file mode 100644 index 229190f..0000000 --- a/example/src/main/java/de/example/data/iris/IrisDataUtil.java +++ /dev/null @@ -1,93 +0,0 @@ -package de.example.data.iris; - -import de.edux.data.reader.CSVIDataReader; -import de.edux.data.reader.IDataReader; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.File; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -public class IrisDataUtil { - private static final Logger logger = LoggerFactory.getLogger(IrisDataUtil.class); - public static List loadIrisDataSet(boolean normalize, boolean shuffle) { - IDataReader dataReader = new CSVIDataReader(); - File csvFile = new File("example"+ File.separator + "datasets"+ File.separator + "iris" + File.separator + "iris.csv"); - List csvLines = dataReader.readFile(csvFile, ','); - List unmodifiableDataset = csvLines.stream().map(IrisDataUtil::mapToIris).toList(); - List dataset = new ArrayList<>(unmodifiableDataset); - if (normalize) { - normalize(dataset); - logger.info("Dataset normalized"); - } - if (shuffle) { - 
Collections.shuffle(dataset); - logger.info("Dataset shuffled"); - } - return dataset; - } - - private static void normalize(List rowDataset) { - double minSepalLength = rowDataset.stream().mapToDouble(iris -> iris.sepalLength).min().getAsDouble(); - double maxSepalLength = rowDataset.stream().mapToDouble(iris -> iris.sepalLength).max().getAsDouble(); - double minSepalWidth = rowDataset.stream().mapToDouble(iris -> iris.sepalWidth).min().getAsDouble(); - double maxSepalWidth = rowDataset.stream().mapToDouble(iris -> iris.sepalWidth).max().getAsDouble(); - double minPetalLength = rowDataset.stream().mapToDouble(iris -> iris.petalLength).min().getAsDouble(); - double maxPetalLength = rowDataset.stream().mapToDouble(iris -> iris.petalLength).max().getAsDouble(); - double minPetalWidth = rowDataset.stream().mapToDouble(iris -> iris.petalWidth).min().getAsDouble(); - double maxPetalWidth = rowDataset.stream().mapToDouble(iris -> iris.petalWidth).max().getAsDouble(); - - for (Iris iris : rowDataset) { - iris.sepalLength = (iris.sepalLength - minSepalLength) / (maxSepalLength - minSepalLength); - iris.sepalWidth = (iris.sepalWidth - minSepalWidth) / (maxSepalWidth - minSepalWidth); - iris.petalLength = (iris.petalLength - minPetalLength) / (maxPetalLength - minPetalLength); - iris.petalWidth = (iris.petalWidth - minPetalWidth) / (maxPetalWidth - minPetalWidth); - } - } - - private static Iris mapToIris(String[] csvLine) { - return new Iris( - Double.parseDouble(csvLine[0]), - Double.parseDouble(csvLine[1]), - Double.parseDouble(csvLine[2]), - Double.parseDouble(csvLine[3]), - csvLine[4] - ); - } - - public static double[][] getInputs(List dataset) { - double[][] inputs = new double[dataset.size()][4]; - for (int i = 0; i < dataset.size(); i++) { - inputs[i][0] = dataset.get(i).sepalLength; - inputs[i][1] = dataset.get(i).sepalWidth; - inputs[i][2] = dataset.get(i).petalLength; - inputs[i][3] = dataset.get(i).petalWidth; - } - return inputs; - } - - public static double[][] getTargets(List dataset) { - double[][] targets = new double[dataset.size()][3]; - for (int i = 0; i < dataset.size(); i++) { - switch (dataset.get(i).variety) { - case "Setosa": - targets[i][0] = 1; - break; - case "Versicolor": - targets[i][1] = 1; - break; - case "Virginica": - targets[i][2] = 1; - break; - } - } - return targets; - } - - public static List> split(List dataset, double v) { - int splitIndex = (int) (dataset.size() * v); - return List.of(dataset.subList(0, splitIndex), dataset.subList(splitIndex, dataset.size())); - } -} diff --git a/example/src/main/java/de/example/data/iris/IrisProvider.java b/example/src/main/java/de/example/data/iris/IrisProvider.java deleted file mode 100644 index cf1b0d0..0000000 --- a/example/src/main/java/de/example/data/iris/IrisProvider.java +++ /dev/null @@ -1,112 +0,0 @@ -package de.example.data.iris; - -import de.edux.data.provider.IDataProvider; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.List; - -import static de.example.data.iris.IrisDataUtil.loadIrisDataSet; - -public class IrisProvider implements IDataProvider { - private static final Logger LOG = LoggerFactory.getLogger(IrisProvider.class); - private final List dataset; - private final List> split; - private final List trainingData; - private final List testData; - - /** - * @param normalize Will normalize the data if true - * @param shuffle Will shuffle the data if true - * @param trainTestSplitRatio Ratio of train and test data e.g. 
0.8 means 80% train data and 20% test data - */ - public IrisProvider(boolean normalize, boolean shuffle, double trainTestSplitRatio) { - dataset = loadIrisDataSet(normalize, shuffle); - split = IrisDataUtil.split(dataset, trainTestSplitRatio); - trainingData = split.get(0); - testData = split.get(1); - } - - @Override - public List getTrainData() { - return trainingData; - } - - @Override - public List getTestData() { - return testData; - } - - @Override - public void printStatistics() { - LOG.info("========================= Data Statistic =================="); - LOG.info("Total dataset size: " + dataset.size()); - LOG.info("Training dataset size: " + trainingData.size()); - LOG.info("Test data set size: " + testData.size()); - LOG.info("Classes: " + getTrainLabels()[0].length); - LOG.info("==========================================================="); - } - - @Override - public Iris getRandom(boolean equalDistribution) { - return dataset.get((int) (Math.random() * dataset.size())); - } - - @Override - public double[][] getTrainFeatures() { - return featuresOf(trainingData); - } - - @Override - public double[][] getTrainLabels() { - return labelsOf(trainingData); - } - - @Override - public double[][] getTestLabels() { - return labelsOf(testData); - } - - @Override - public double[][] getTestFeatures() { - return featuresOf(testData); - } - - private double[][] featuresOf(List testData) { - double[][] features = new double[testData.size()][4]; - for (int i = 0; i < testData.size(); i++) { - features[i][0] = testData.get(i).sepalLength; - features[i][1] = testData.get(i).sepalWidth; - features[i][2] = testData.get(i).petalLength; - features[i][3] = testData.get(i).petalWidth; - } - return features; - } - private double[][] labelsOf(List data) { - double[][] labels = new double[data.size()][3]; - for (int i = 0; i < data.size(); i++) { - if (data.get(i).variety.equals("Setosa")) { - labels[i][0] = 1; - labels[i][1] = 0; - labels[i][2] = 0; - } - if (data.get(i).variety.equals("Versicolor")) { - labels[i][0] = 0; - labels[i][1] = 1; - labels[i][2] = 0; - } - if (data.get(i).variety.equals("Virginica")) { - labels[i][0] = 0; - labels[i][1] = 0; - labels[i][2] = 1; - } - } - return labels; - } - - - @Override - public String getDescription() { - return "Iris dataset"; - } -} diff --git a/example/src/main/java/de/example/data/seaborn/Penguin.java b/example/src/main/java/de/example/data/seaborn/Penguin.java deleted file mode 100644 index d2dd62f..0000000 --- a/example/src/main/java/de/example/data/seaborn/Penguin.java +++ /dev/null @@ -1,10 +0,0 @@ -package de.example.data.seaborn; - -/** - * SeaBorn penguin dto - */ -public record Penguin(String species, String island, double billLengthMm, double billDepthMm, int flipperLengthMm, int bodyMassG, String sex){ - public double[] getFeatures() { - return new double[]{billLengthMm, billDepthMm, flipperLengthMm, bodyMassG}; - } -} diff --git a/example/src/main/java/de/example/data/seaborn/SeabornDataProcessor.java b/example/src/main/java/de/example/data/seaborn/SeabornDataProcessor.java deleted file mode 100644 index 376767e..0000000 --- a/example/src/main/java/de/example/data/seaborn/SeabornDataProcessor.java +++ /dev/null @@ -1,184 +0,0 @@ -package de.example.data.seaborn; - -import de.edux.data.provider.DataProcessor; - -import java.util.ArrayList; -import java.util.List; - -public class SeabornDataProcessor extends DataProcessor { - @Override - public void normalize(List penguins) { - double maxBillLength = 
penguins.stream().mapToDouble(Penguin::billLengthMm).max().orElse(1); - double minBillLength = penguins.stream().mapToDouble(Penguin::billLengthMm).min().orElse(0); - - double maxBillDepth = penguins.stream().mapToDouble(Penguin::billDepthMm).max().orElse(1); - double minBillDepth = penguins.stream().mapToDouble(Penguin::billDepthMm).min().orElse(0); - - double maxFlipperLength = penguins.stream().mapToInt(Penguin::flipperLengthMm).max().orElse(1); - double minFlipperLength = penguins.stream().mapToInt(Penguin::flipperLengthMm).min().orElse(0); - - double maxBodyMass = penguins.stream().mapToInt(Penguin::bodyMassG).max().orElse(1); - double minBodyMass = penguins.stream().mapToInt(Penguin::bodyMassG).min().orElse(0); - - List normalizedPenguins = new ArrayList<>(); - for (Penguin p : penguins) { - double normalizedBillLength = (p.billLengthMm() - minBillLength) / (maxBillLength - minBillLength); - double normalizedBillDepth = (p.billDepthMm() - minBillDepth) / (maxBillDepth - minBillDepth); - double normalizedFlipperLength = (p.flipperLengthMm() - minFlipperLength) / (maxFlipperLength - minFlipperLength); - double normalizedBodyMass = (p.bodyMassG() - minBodyMass) / (maxBodyMass - minBodyMass); - - p = new Penguin(p.species(), p.island(), normalizedBillLength, normalizedBillDepth, (int) normalizedFlipperLength, (int) normalizedBodyMass, p.sex()); - - Penguin normalizedPenguin = new Penguin(p.species(), p.island(), normalizedBillLength, normalizedBillDepth, (int) normalizedFlipperLength, (int) normalizedBodyMass, p.sex()); - normalizedPenguins.add(normalizedPenguin); - } - penguins.clear(); - penguins.addAll(normalizedPenguins); - } - - @Override - public Penguin mapToDataRecord(String[] csvLine) { - if (csvLine.length != 7) { - throw new IllegalArgumentException("CSV line format is invalid. 
Expected 7 fields, got " + csvLine.length + "."); - } - - for (String value : csvLine) { - if (value == null || value.trim().isEmpty()) { - return null; - } - } - - String species = csvLine[0]; - if (!(species.equalsIgnoreCase("adelie") || species.equalsIgnoreCase("chinstrap") || species.equalsIgnoreCase("gentoo"))) { - throw new IllegalArgumentException("Invalid species: " + species); - } - - String island = csvLine[1]; - - double billLengthMm; - double billDepthMm; - int flipperLengthMm; - int bodyMassG; - - try { - billLengthMm = Double.parseDouble(csvLine[2]); - if (billLengthMm < 0) { - throw new IllegalArgumentException("Bill length cannot be negative."); - } - - billDepthMm = Double.parseDouble(csvLine[3]); - if (billDepthMm < 0) { - throw new IllegalArgumentException("Bill depth cannot be negative."); - } - - flipperLengthMm = Integer.parseInt(csvLine[4]); - if (flipperLengthMm < 0) { - throw new IllegalArgumentException("Flipper length cannot be negative."); - } - - bodyMassG = Integer.parseInt(csvLine[5]); - if (bodyMassG < 0) { - throw new IllegalArgumentException("Body mass cannot be negative."); - } - } catch (NumberFormatException e) { - throw new IllegalArgumentException("Invalid number format in CSV line", e); - } - - String sex = csvLine[6]; - if (!(sex.equalsIgnoreCase("male") || sex.equalsIgnoreCase("female"))) { - throw new IllegalArgumentException("Invalid sex: " + sex); - } - - return new Penguin(species, island, billLengthMm, billDepthMm, flipperLengthMm, bodyMassG, sex); - } - - @Override - public double[][] getInputs(List dataset) { - double[][] inputs = new double[dataset.size()][4]; - - for (int i = 0; i < dataset.size(); i++) { - Penguin p = dataset.get(i); - inputs[i][0] = p.billLengthMm(); - inputs[i][1] = p.billDepthMm(); - inputs[i][2] = p.flipperLengthMm(); - inputs[i][3] = p.bodyMassG(); - } - - return inputs; - } - - @Override - public double[][] getTargets(List dataset) { - double[][] targets = new double[dataset.size()][1]; - - for (int i = 0; i < dataset.size(); i++) { - Penguin p = dataset.get(i); - targets[i][0] = "Male".equalsIgnoreCase(p.sex()) ? 
1.0 : 0.0; - } - - return targets; - } - - @Override - public String getDatasetDescription() { - return "Seaborn penguins dataset"; - } - - @Override - public double[][] getTrainFeatures() { - return featuresOf(getSplitedDataset().trainData()); - } - - @Override - public double[][] getTrainLabels() { - return labelsOf(getSplitedDataset().trainData()); - } - - @Override - public double[][] getTestFeatures() { - return featuresOf(getSplitedDataset().testData()); - } - - @Override - public double[][] getTestLabels() { - return labelsOf(getSplitedDataset().testData()); - } - - private double[][] featuresOf(List data) { - double[][] features = new double[data.size()][4]; - - for (int i = 0; i < data.size(); i++) { - Penguin p = data.get(i); - features[i][0] = p.billLengthMm(); - features[i][1] = p.billDepthMm(); - features[i][2] = p.flipperLengthMm(); - features[i][3] = p.bodyMassG(); - } - - return features; - } - - private double[][] labelsOf(List data) { - double[][] labels = new double[data.size()][3]; - - for (int i = 0; i < data.size(); i++) { - Penguin p = data.get(i); - switch (p.species().toLowerCase()) { - case "adelie": - labels[i] = new double[]{1.0, 0.0, 0.0}; - break; - case "chinstrap": - labels[i] = new double[]{0.0, 1.0, 0.0}; - break; - case "gentoo": - labels[i] = new double[]{0.0, 0.0, 1.0}; - break; - default: - throw new IllegalArgumentException("Unbekannte Pinguinart: " + p.species()); - } - } - - return labels; - } - -} \ No newline at end of file diff --git a/example/src/main/java/de/example/data/seaborn/SeabornProvider.java b/example/src/main/java/de/example/data/seaborn/SeabornProvider.java deleted file mode 100644 index e2e5416..0000000 --- a/example/src/main/java/de/example/data/seaborn/SeabornProvider.java +++ /dev/null @@ -1,108 +0,0 @@ -package de.example.data.seaborn; - -import de.edux.data.provider.IDataProvider; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.List; - -public class SeabornProvider implements IDataProvider { - private static final Logger LOG = LoggerFactory.getLogger(SeabornProvider.class); - private final List dataset; - private final List trainingData; - private final List testData; - - public SeabornProvider(List dataset, List trainingData, List testData) { - this.dataset = dataset; - this.trainingData = trainingData; - this.testData = testData; - } - - @Override - public List getTrainData() { - return trainingData; - } - - @Override - public List getTestData() { - return testData; - } - - @Override - public void printStatistics() { - LOG.info("========================= Data Statistic =================="); - LOG.info("Dataset: Seaborn Penguins"); - LOG.info("Description: " + getDescription()); - LOG.info("Total dataset size: " + dataset.size()); - LOG.info("Training dataset size: " + trainingData.size()); - LOG.info("Test data set size: " + testData.size()); - LOG.info("Classes: " + getTrainLabels()[0].length); - LOG.info("==========================================================="); - } - - @Override - public Penguin getRandom(boolean equalDistribution) { - return dataset.get((int) (Math.random() * dataset.size())); - } - - @Override - public double[][] getTrainFeatures() { - return featuresOf(trainingData); - } - @Override - public double[][] getTrainLabels() { - return labelsOf(trainingData); - } - @Override - public double[][] getTestFeatures() { - return featuresOf(testData); - } - - @Override - public double[][] getTestLabels() { - return labelsOf(testData); - } - - private double[][] featuresOf(List 
data) { - double[][] features = new double[data.size()][4]; - - for (int i = 0; i < data.size(); i++) { - Penguin p = data.get(i); - features[i][0] = p.billLengthMm(); - features[i][1] = p.billDepthMm(); - features[i][2] = p.flipperLengthMm(); - features[i][3] = p.bodyMassG(); - } - - return features; - } - private double[][] labelsOf(List data) { - double[][] labels = new double[data.size()][3]; - - for (int i = 0; i < data.size(); i++) { - Penguin p = data.get(i); - switch (p.species().toLowerCase()) { - case "adelie": - labels[i] = new double[]{1.0, 0.0, 0.0}; - break; - case "chinstrap": - labels[i] = new double[]{0.0, 1.0, 0.0}; - break; - case "gentoo": - labels[i] = new double[]{0.0, 0.0, 1.0}; - break; - default: - throw new IllegalArgumentException("Unbekannte Pinguinart: " + p.species()); - } - } - - return labels; - } - - - - @Override - public String getDescription() { - return "The Seaborn Penguin dataset comprises measurements and species classifications for penguins collected from three islands in the Palmer Archipelago, Antarctica."; - } -} diff --git a/example/src/main/java/de/example/decisiontree/DecisionTreeExample.java b/example/src/main/java/de/example/decisiontree/DecisionTreeExample.java deleted file mode 100644 index ddbcc05..0000000 --- a/example/src/main/java/de/example/decisiontree/DecisionTreeExample.java +++ /dev/null @@ -1,28 +0,0 @@ -package de.example.decisiontree; - -import de.edux.ml.decisiontree.DecisionTree; -import de.example.data.iris.IrisProvider; - -import java.util.Arrays; - -public class DecisionTreeExample { - private static final boolean SHUFFLE = true; - private static final boolean NORMALIZE = true; - - public static void main(String[] args) { - - var datasetProvider = new IrisProvider(NORMALIZE, SHUFFLE, 0.6); - datasetProvider.printStatistics(); - - double[][] features = datasetProvider.getTrainFeatures(); - double[][] labels = datasetProvider.getTrainLabels(); - - double[][] testFeatures = datasetProvider.getTestFeatures(); - double[][] testLabels = datasetProvider.getTestLabels(); - - DecisionTree decisionTree = new DecisionTree(8, 2, 1, 4); - decisionTree.train(features, labels); - decisionTree.evaluate(testFeatures, testLabels); - - } -} diff --git a/example/src/main/java/de/example/knn/KnnIrisExample.java b/example/src/main/java/de/example/knn/KnnIrisExample.java index e4d1850..de31f91 100644 --- a/example/src/main/java/de/example/knn/KnnIrisExample.java +++ b/example/src/main/java/de/example/knn/KnnIrisExample.java @@ -1,34 +1,36 @@ package de.example.knn; import de.edux.api.Classifier; +import de.edux.data.provider.DataProcessor; +import de.edux.data.reader.CSVIDataReader; import de.edux.ml.knn.KnnClassifier; -import de.edux.ml.nn.network.api.Dataset; -import de.example.data.iris.Iris; -import de.example.data.iris.IrisDataProcessor; import java.io.File; -import java.util.List; -/** - * Knn - K nearest neighbors - * Dataset: Iris - * First transfer the iris data into KnnPoints, use the variety as label. Then use the KnnClassifier to classify the test data. 
- */ public class KnnIrisExample { - private static final boolean SHUFFLE = true; - private static final boolean NORMALIZE = true; - private static final boolean FILTER_INCOMPLETE_RECORDS = true; - private static final double TRAIN_TEST_SPLIT_RATIO = 0.75; + private static final double TRAIN_TEST_SPLIT_RATIO = 0.70; private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "iris" + File.separator + "iris.csv"); + private static final boolean SKIP_HEAD = true; public static void main(String[] args) { - var irisDataProcessor = new IrisDataProcessor(); - List data = irisDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS); - irisDataProcessor.split(data, TRAIN_TEST_SPLIT_RATIO); + var featureColumnIndices = new int[]{0, 1, 2, 3}; + var targetColumnIndex = 4; + + var irisDataProcessor = new DataProcessor(new CSVIDataReader()) + .loadDataSetFromCSV(CSV_FILE, ',', SKIP_HEAD, featureColumnIndices, targetColumnIndex) + .normalize() + .shuffle() + .split(TRAIN_TEST_SPLIT_RATIO); + Classifier knn = new KnnClassifier(2); - knn.train(irisDataProcessor.getTrainFeatures(), irisDataProcessor.getTrainLabels()); - knn.evaluate(irisDataProcessor.getTestFeatures(), irisDataProcessor.getTestLabels()); + var trainFeatures = irisDataProcessor.getTrainFeatures(featureColumnIndices); + var trainTestFeatures = irisDataProcessor.getTestFeatures(featureColumnIndices); + var trainLabels = irisDataProcessor.getTrainLabels(targetColumnIndex); + var trainTestLabels = irisDataProcessor.getTestLabels(targetColumnIndex); + + knn.train(trainFeatures, trainLabels); + knn.evaluate(trainTestFeatures, trainTestLabels); } -} \ No newline at end of file +} diff --git a/example/src/main/java/de/example/knn/KnnSeabornExample.java b/example/src/main/java/de/example/knn/KnnSeabornExample.java deleted file mode 100644 index 324d438..0000000 --- a/example/src/main/java/de/example/knn/KnnSeabornExample.java +++ /dev/null @@ -1,37 +0,0 @@ -package de.example.knn; - -import de.edux.api.Classifier; -import de.edux.ml.knn.KnnClassifier; -import de.edux.ml.nn.network.api.Dataset; -import de.example.data.seaborn.Penguin; -import de.example.data.seaborn.SeabornDataProcessor; -import de.example.data.seaborn.SeabornProvider; - -import java.io.File; -import java.util.List; - -/** - * Knn - K nearest neighbors - * Dataset: Seaborn Penguins - */ -public class KnnSeabornExample { - private static final boolean SHUFFLE = true; - private static final boolean NORMALIZE = true; - private static final boolean FILTER_INCOMPLETE_RECORDS = true; - private static final double TRAIN_TEST_SPLIT_RATIO = 0.75; - private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "seaborn-penguins" + File.separator + "penguins.csv"); - - public static void main(String[] args) { - var seabornDataProcessor = new SeabornDataProcessor(); - List data = seabornDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS); - - Dataset dataset = seabornDataProcessor.split(data, TRAIN_TEST_SPLIT_RATIO); - var seabornProvider = new SeabornProvider(data, dataset.trainData(), dataset.testData()); - seabornProvider.printStatistics(); - - Classifier knn = new KnnClassifier(2); - knn.train(seabornProvider.getTrainFeatures(), seabornProvider.getTrainLabels()); - knn.evaluate(seabornProvider.getTestFeatures(), seabornProvider.getTestLabels()); - } - -} diff --git 
a/example/src/main/java/de/example/nn/MultilayerPerceptronExample.java b/example/src/main/java/de/example/nn/MultilayerPerceptronExample.java deleted file mode 100644 index 93655f1..0000000 --- a/example/src/main/java/de/example/nn/MultilayerPerceptronExample.java +++ /dev/null @@ -1,45 +0,0 @@ -package de.example.nn; - -import de.edux.functions.activation.ActivationFunction; -import de.edux.functions.initialization.Initialization; -import de.edux.functions.loss.LossFunction; -import de.edux.ml.nn.config.NetworkConfiguration; -import de.edux.ml.nn.network.MultilayerPerceptron; -import de.example.data.iris.IrisProvider; - -import java.util.List; - -public class MultilayerPerceptronExample { - - private static final boolean SHUFFLE = true; - private final static boolean NORMALIZE = true; - - public static void main(String[] args) { - var datasetProvider = new IrisProvider(NORMALIZE, SHUFFLE, 0.7); - datasetProvider.printStatistics(); - - double[][] features = datasetProvider.getTrainFeatures(); - double[][] labels = datasetProvider.getTrainLabels(); - - double[][] testFeatures = datasetProvider.getTestFeatures(); - double[][] testLabels = datasetProvider.getTestLabels(); - - //Configure Network with: - // - 4 Input Neurons - // - 2 Hidden Layer with 12 and 6 Neurons - // - 3 Output Neurons - // - Learning Rate of 0.1 - // - 1000 Epochs - // - Leaky ReLU as Activation Function for Hidden Layers - // - Softmax as Activation Function for Output Layer - // - Categorical Cross Entropy as Loss Function - // - Xavier as Weight Initialization for Hidden Layers - // - Xavier as Weight Initialization for Output Layer - NetworkConfiguration networkConfiguration = new NetworkConfiguration(features[0].length, List.of(32, 6), 3, 0.01, 1000, ActivationFunction.LEAKY_RELU, ActivationFunction.SOFTMAX, LossFunction.CATEGORICAL_CROSS_ENTROPY, Initialization.XAVIER, Initialization.XAVIER); - - MultilayerPerceptron multilayerPerceptron = new MultilayerPerceptron(networkConfiguration, testFeatures, testLabels); - multilayerPerceptron.train(features, labels); - multilayerPerceptron.evaluate(testFeatures, testLabels); - } -} - diff --git a/example/src/main/java/de/example/nn/MultilayerPerceptronExampleV2.java b/example/src/main/java/de/example/nn/MultilayerPerceptronExampleV2.java new file mode 100644 index 0000000..7072a33 --- /dev/null +++ b/example/src/main/java/de/example/nn/MultilayerPerceptronExampleV2.java @@ -0,0 +1,58 @@ +package de.example.nn; + +import de.edux.data.provider.DataProcessor; +import de.edux.data.reader.CSVIDataReader; +import de.edux.functions.activation.ActivationFunction; +import de.edux.functions.initialization.Initialization; +import de.edux.functions.loss.LossFunction; +import de.edux.ml.nn.config.NetworkConfiguration; +import de.edux.ml.nn.network.MultilayerPerceptron; + +import java.io.File; +import java.util.List; + +public class MultilayerPerceptronExampleV2 { + + private static final double TRAIN_TEST_SPLIT_RATIO = 0.70; + private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "iris" + File.separator + "iris.csv"); + private static final boolean SKIP_HEAD = true; + + public static void main(String[] args) { + var featureColumnIndices = new int[]{0, 1, 2, 3}; + var targetColumnIndex = 4; + + var dataProcessor = new DataProcessor(new CSVIDataReader()); + var dataset = dataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SKIP_HEAD, featureColumnIndices, targetColumnIndex); + dataset.shuffle(); + dataset.normalize(); + 
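// loadDataSetFromCSV returns the processor itself, so `dataset` and
+ // `dataProcessor` refer to the same instance here. shuffle() and normalize()
+ // mutate the backing rows in place, and split() must run before the
+ // train/test getters below, which read the split sublists.
+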
dataProcessor.split(TRAIN_TEST_SPLIT_RATIO); + + + + var trainFeatures = dataProcessor.getTrainFeatures(featureColumnIndices); + var trainLabels = dataProcessor.getTrainLabels(targetColumnIndex); + var testFeatures = dataProcessor.getTestFeatures(featureColumnIndices); + var testLabels = dataProcessor.getTestLabels(targetColumnIndex); + + var classMap = dataProcessor.getClassMap(); + + System.out.println("Class Map: " + classMap); + + //Configure Network with: + // - 4 Input Neurons + // - 2 Hidden Layer with 12 and 6 Neurons + // - 3 Output Neurons + // - Learning Rate of 0.1 + // - 1000 Epochs + // - Leaky ReLU as Activation Function for Hidden Layers + // - Softmax as Activation Function for Output Layer + // - Categorical Cross Entropy as Loss Function + // - Xavier as Weight Initialization for Hidden Layers + // - Xavier as Weight Initialization for Output Layer + var networkConfiguration = new NetworkConfiguration(trainFeatures[0].length, List.of(128, 256, 512), 3, 0.01, 300, ActivationFunction.LEAKY_RELU, ActivationFunction.SOFTMAX, LossFunction.CATEGORICAL_CROSS_ENTROPY, Initialization.XAVIER, Initialization.XAVIER); + + MultilayerPerceptron multilayerPerceptron = new MultilayerPerceptron(networkConfiguration, testFeatures, testLabels); + multilayerPerceptron.train(trainFeatures, trainLabels); + multilayerPerceptron.evaluate(testFeatures, testLabels); + } +} diff --git a/example/src/main/java/de/example/nn/MultilayerPerceptronSeabornExample.java b/example/src/main/java/de/example/nn/MultilayerPerceptronSeabornExample.java deleted file mode 100644 index 9afcb6f..0000000 --- a/example/src/main/java/de/example/nn/MultilayerPerceptronSeabornExample.java +++ /dev/null @@ -1,44 +0,0 @@ -package de.example.nn; - - -import de.edux.api.Classifier; -import de.edux.functions.activation.ActivationFunction; -import de.edux.functions.initialization.Initialization; -import de.edux.functions.loss.LossFunction; -import de.edux.ml.nn.network.api.Dataset; -import de.example.data.seaborn.Penguin; -import de.edux.ml.nn.config.NetworkConfiguration; -import de.edux.ml.nn.network.MultilayerPerceptron; -import de.example.data.seaborn.SeabornDataProcessor; -import de.example.data.seaborn.SeabornProvider; - -import java.io.File; -import java.util.List; - -public class MultilayerPerceptronSeabornExample { - private static final boolean SHUFFLE = true; - private static final boolean NORMALIZE = true; - private static final boolean FILTER_INCOMPLETE_RECORDS = true; - private static final double TRAIN_TEST_SPLIT_RATIO = 0.75; - private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "seaborn-penguins" + File.separator + "penguins.csv"); - - public static void main(String[] args) { - var seabornDataProcessor = new SeabornDataProcessor(); - List data = seabornDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS); - - Dataset dataset = seabornDataProcessor.split(data, TRAIN_TEST_SPLIT_RATIO); - var seabornProvider = new SeabornProvider(data, dataset.trainData(), dataset.testData()); - seabornProvider.printStatistics(); - double[][] features = seabornProvider.getTrainFeatures(); - double[][] labels = seabornProvider.getTrainLabels(); - - double[][] testFeatures = seabornProvider.getTestFeatures(); - double[][] testLabels = seabornProvider.getTestLabels(); - - NetworkConfiguration networkConfiguration = new NetworkConfiguration(features[0].length, List.of(32, 6), 3, 0.01, 1000, ActivationFunction.LEAKY_RELU, 
ActivationFunction.SOFTMAX, LossFunction.CATEGORICAL_CROSS_ENTROPY, Initialization.XAVIER, Initialization.XAVIER); - Classifier multilayerPerceptron = new MultilayerPerceptron(networkConfiguration, testFeatures, testLabels); - multilayerPerceptron.train(features, labels); - multilayerPerceptron.evaluate(testFeatures, testLabels); - } - -} diff --git a/example/src/main/java/de/example/randomforest/RandomForestExample.java b/example/src/main/java/de/example/randomforest/RandomForestExample.java deleted file mode 100644 index c35080d..0000000 --- a/example/src/main/java/de/example/randomforest/RandomForestExample.java +++ /dev/null @@ -1,28 +0,0 @@ -package de.example.randomforest; - -import de.edux.api.Classifier; -import de.edux.ml.randomforest.RandomForest; -import de.example.data.iris.IrisProvider; - -public class RandomForestExample { - private static final boolean SHUFFLE = true; - private static final boolean NORMALIZE = true; - - public static void main(String[] args) { - - var datasetProvider = new IrisProvider(NORMALIZE, SHUFFLE, 0.6); - datasetProvider.printStatistics(); - - double[][] trainFeatures = datasetProvider.getTrainFeatures(); - double[][] trainLabels = datasetProvider.getTrainLabels(); - - Classifier randomForest = new RandomForest(100, 10, 2, 1, 3, 60); - randomForest.train(trainFeatures, trainLabels); - - double[][] testFeatures = datasetProvider.getTestFeatures(); - double[][] testLabels = datasetProvider.getTestLabels(); - - randomForest.evaluate(testFeatures, testLabels); - - } -} diff --git a/example/src/main/java/de/example/svm/SVMExample.java b/example/src/main/java/de/example/svm/SVMExample.java deleted file mode 100644 index 7ef1984..0000000 --- a/example/src/main/java/de/example/svm/SVMExample.java +++ /dev/null @@ -1,31 +0,0 @@ -package de.example.svm; - -import de.edux.api.Classifier; -import de.edux.ml.svm.ISupportVectorMachine; -import de.edux.ml.svm.SVMKernel; -import de.edux.ml.svm.SupportVectorMachine; -import de.example.data.iris.IrisProvider; - -public class SVMExample { - private static final boolean SHUFFLE = true; - private static final boolean NORMALIZE = true; - - public static void main(String[] args){ - var datasetProvider = new IrisProvider(NORMALIZE, SHUFFLE, 0.6); - datasetProvider.printStatistics(); - - var features = datasetProvider.getTrainFeatures(); - var labels = datasetProvider.getTrainLabels(); - - Classifier supportVectorMachine = new SupportVectorMachine(SVMKernel.LINEAR, 1); - - supportVectorMachine.train(features, labels); - - double[][] testFeatures = datasetProvider.getTestFeatures(); - double[][] testLabels = datasetProvider.getTestLabels(); - - supportVectorMachine.evaluate(testFeatures, testLabels); - } - - -} diff --git a/lib/src/main/java/de/edux/data/provider/DataNormalizer.java b/lib/src/main/java/de/edux/data/provider/DataNormalizer.java new file mode 100644 index 0000000..0df8972 --- /dev/null +++ b/lib/src/main/java/de/edux/data/provider/DataNormalizer.java @@ -0,0 +1,62 @@ +package de.edux.data.provider; + +import java.util.List; +import java.util.ArrayList; + +public class DataNormalizer implements Normalizer { + + @Override + public List normalize(List dataset) { + if (dataset == null || dataset.isEmpty()) { + return dataset; + } + + int columnCount = dataset.get(0).length; + + double[] minValues = new double[columnCount]; + double[] maxValues = new double[columnCount]; + boolean[] isNumericColumn = new boolean[columnCount]; + + for (int i = 0; i < columnCount; i++) { + minValues[i] = Double.MAX_VALUE; + 
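// -Double.MAX_VALUE is the lowest finite double and is the right "smallest"
+ // sentinel here; Double.MIN_VALUE would be wrong, since it is the smallest
+ // positive double and would never be updated for all-negative columns.
+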
maxValues[i] = -Double.MAX_VALUE; + isNumericColumn[i] = true; + } + + for (String[] row : dataset) { + for (int colIndex = 0; colIndex < columnCount; colIndex++) { + try { + double numValue = Double.parseDouble(row[colIndex]); + + if (numValue < minValues[colIndex]) { + minValues[colIndex] = numValue; + } + if (numValue > maxValues[colIndex]) { + maxValues[colIndex] = numValue; + } + } catch (NumberFormatException e) { + isNumericColumn[colIndex] = false; + } + } + } + + for (String[] row : dataset) { + for (int colIndex = 0; colIndex < columnCount; colIndex++) { + if (isNumericColumn[colIndex]) { + double numValue = Double.parseDouble(row[colIndex]); + double range = maxValues[colIndex] - minValues[colIndex]; + + if (range != 0.0) { + double normalized = (numValue - minValues[colIndex]) / range; + row[colIndex] = String.valueOf(normalized); + } else { + row[colIndex] = "0"; + } + } + } + } + + return dataset; + } + +} diff --git a/lib/src/main/java/de/edux/data/provider/DataPostProcessor.java b/lib/src/main/java/de/edux/data/provider/DataPostProcessor.java index 263276c..990b8f3 100644 --- a/lib/src/main/java/de/edux/data/provider/DataPostProcessor.java +++ b/lib/src/main/java/de/edux/data/provider/DataPostProcessor.java @@ -1,24 +1,21 @@ package de.edux.data.provider; -import java.util.List; - -public abstract class DataPostProcessor { - public abstract void normalize(List rowDataset); +import de.edux.functions.imputation.ImputationStrategy; - public abstract T mapToDataRecord(String[] csvLine); +import java.util.List; - public abstract double[][] getInputs(List dataset); +public interface DataPostProcessor { + DataPostProcessor normalize(); - public abstract double[][] getTargets(List dataset); + DataPostProcessor shuffle(); - public abstract String getDatasetDescription(); + DataPostProcessor imputation(String columnName, ImputationStrategy imputationStrategy); - public abstract double[][] getTrainFeatures(); + DataPostProcessor imputation(int columnIndex, ImputationStrategy imputationStrategy); - public abstract double[][] getTrainLabels(); + List getDataset(); - public abstract double[][] getTestLabels(); + DataProcessor split(double splitRatio); - public abstract double[][] getTestFeatures(); } diff --git a/lib/src/main/java/de/edux/data/provider/DataProcessor.java b/lib/src/main/java/de/edux/data/provider/DataProcessor.java index e1f13bc..2bb7f55 100644 --- a/lib/src/main/java/de/edux/data/provider/DataProcessor.java +++ b/lib/src/main/java/de/edux/data/provider/DataProcessor.java @@ -1,81 +1,160 @@ package de.edux.data.provider; -import de.edux.data.reader.CSVIDataReader; import de.edux.data.reader.IDataReader; -import de.edux.ml.nn.network.api.Dataset; +import de.edux.functions.imputation.ImputationStrategy; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; +import java.util.*; -public abstract class DataProcessor extends DataPostProcessor implements IDataUtil { +public class DataProcessor implements DataPostProcessor, Dataset, DataloaderV2 { private static final Logger LOG = LoggerFactory.getLogger(DataProcessor.class); - private final IDataReader csvDataReader; - private ArrayList dataset; - private Dataset splitedDataset; + private String[] columnNames; + private final IDataReader dataReader; + private final Normalizer normalizer; + private List dataset; + private List trainData; + private List testData; - public DataProcessor() { - this.csvDataReader = new 
CSVIDataReader(); + @Override + public DataProcessor split(double splitRatio) { + int splitIndex = (int) (dataset.size() * splitRatio); + trainData = dataset.subList(0, splitIndex); + testData = dataset.subList(splitIndex, dataset.size()); + + return this; } - public DataProcessor(IDataReader csvDataReader) { - this.csvDataReader = csvDataReader; + public DataProcessor(IDataReader dataReader) { + this.dataReader = dataReader; + normalizer = new DataNormalizer(); } @Override - public List loadDataSetFromCSV(File csvFile, char csvSeparator, boolean normalize, boolean shuffle, boolean filterIncompleteRecords) { - List x = csvDataReader.readFile(csvFile, csvSeparator); - List unmodifiableDataset = csvDataReader.readFile(csvFile, csvSeparator) - .stream() - .map(this::mapToDataRecord) - .filter(record -> !filterIncompleteRecords || record != null) - .toList(); - - dataset = new ArrayList<>(unmodifiableDataset); - LOG.info("Dataset loaded"); + public DataProcessor loadDataSetFromCSV(File csvFile, char csvSeparator, boolean skipHead, int[] inputColumns, int targetColumn) { + dataset = dataReader.readFile(csvFile, csvSeparator); - if (normalize) { - normalize(dataset); - LOG.info("Dataset normalized"); + if (skipHead) { + skipHead(); } - if (shuffle) { - Collections.shuffle(dataset); - LOG.info("Dataset shuffled"); + List uniqueClasses = new ArrayList<>(); + for (String[] row : dataset) { + String label = row[targetColumn]; + if (!uniqueClasses.contains(label)) { + uniqueClasses.add(label); + } } + + for (int i = 0; i < uniqueClasses.size(); i++) { + indexToClassMap.put(uniqueClasses.get(i), i); + } + + LOG.info("Dataset loaded"); + return this; + } + + private void skipHead() { + columnNames = dataset.remove(0); + } + + @Override + public DataPostProcessor normalize() { + this.dataset = this.normalizer.normalize(dataset); + return this; + } + + @Override + public DataPostProcessor shuffle() { + Collections.shuffle(dataset); + return this; + } + + + @Override + public List getDataset() { return dataset; } - /** - * Split data into train and test data - * - * @param data data to split - * @param trainTestSplitRatio ratio of train data - * @return list of train and test data. First element is train data, second element is test data. 
-     */
     @Override
-    public Dataset<T> split(List<T> data, double trainTestSplitRatio) {
-        if (trainTestSplitRatio < 0.0 || trainTestSplitRatio > 1.0) {
-            throw new IllegalArgumentException("Train-test split ratio must be between 0.0 and 1.0");
+    public double[][] getInputs(List<String[]> dataset, int[] inputColumns) {
+        if (dataset == null || dataset.isEmpty() || inputColumns == null || inputColumns.length == 0) {
+            throw new IllegalArgumentException("Did you call split() before?");
         }
-        int trainSize = (int) (data.size() * trainTestSplitRatio);
+        int numRows = dataset.size();
+        double[][] inputs = new double[numRows][inputColumns.length];
-        List<T> trainDataset = data.subList(0, trainSize);
-        List<T> testDataset = data.subList(trainSize, data.size());
+        for (int i = 0; i < numRows; i++) {
+            String[] row = dataset.get(i);
+            for (int j = 0; j < inputColumns.length; j++) {
+                int colIndex = inputColumns[j];
+                try {
+                    inputs[i][j] = Double.parseDouble(row[colIndex]);
+                } catch (NumberFormatException e) {
+                    inputs[i][j] = 0;
+                }
+            }
+        }
-        splitedDataset = new Dataset<>(trainDataset, testDataset);
-        return splitedDataset;
+        return inputs;
     }
-    public ArrayList<T> getDataset() {
-        return dataset;
+
+
+    private Map<String, Integer> indexToClassMap = new HashMap<>();
+
+    @Override
+    public double[][] getTargets(List<String[]> dataset, int targetColumn) {
+        if (dataset == null || dataset.isEmpty()) {
+            throw new IllegalArgumentException("Dataset must not be empty.");
+        }
+
+        double[][] targets = new double[dataset.size()][indexToClassMap.size()];
+        for (int i = 0; i < dataset.size(); i++) {
+            String value = dataset.get(i)[targetColumn];
+            int index = indexToClassMap.get(value);
+            targets[i][index] = 1.0;
+        }
+
+        return targets;
     }
-    public Dataset<T> getSplitedDataset() {
-        return splitedDataset;
+    @Override
+    public Map<String, Integer> getClassMap() {
+        return indexToClassMap;
     }
-}
+
+    @Override
+    public DataPostProcessor imputation(String columnName, ImputationStrategy imputationStrategy) {
+        return null;
+    }
+
+    @Override
+    public DataPostProcessor imputation(int columnIndex, ImputationStrategy imputationStrategy) {
+        return null;
+    }
+
+
+    @Override
+    public double[][] getTrainFeatures(int[] inputColumns) {
+        return getInputs(trainData, inputColumns);
+    }
+
+    @Override
+    public double[][] getTrainLabels(int targetColumn) {
+        return getTargets(trainData, targetColumn);
+    }
+
+    @Override
+    public double[][] getTestFeatures(int[] inputColumns) {
+        return getInputs(testData, inputColumns);
+    }
+
+    @Override
+    public double[][] getTestLabels(int targetColumn) {
+        return getTargets(testData, targetColumn);
+    }
+
+}
diff --git a/lib/src/main/java/de/edux/data/provider/DataloaderV2.java b/lib/src/main/java/de/edux/data/provider/DataloaderV2.java
new file mode 100644
index 0000000..bfd0377
--- /dev/null
+++ b/lib/src/main/java/de/edux/data/provider/DataloaderV2.java
@@ -0,0 +1,8 @@
+package de.edux.data.provider;
+
+import java.io.File;
+
+public interface DataloaderV2 {
+    DataProcessor loadDataSetFromCSV(File csvFile, char csvSeparator, boolean skipHead, int[] inputColumns, int targetColumn);
+}
+
diff --git a/lib/src/main/java/de/edux/data/provider/Dataset.java b/lib/src/main/java/de/edux/data/provider/Dataset.java
new file mode 100644
index 0000000..8433d4f
--- /dev/null
+++ b/lib/src/main/java/de/edux/data/provider/Dataset.java
@@ -0,0 +1,23 @@
+package de.edux.data.provider;
+
+import java.util.List;
+import java.util.Map;
+
+public interface Dataset {
+
+    double[][] getInputs(List<String[]> dataset, int[] inputColumns);
+
+    double[][] getTargets(List<String[]> dataset, int targetColumn);
+
+    Map<String, Integer> getClassMap();
+
+    double[][] getTrainFeatures(int[] inputColumns);
+
+    double[][] getTrainLabels(int targetColumn);
+
+    double[][] getTestLabels(int targetColumn);
+
+    double[][] getTestFeatures(int[] inputColumns);
+
+
+}
diff --git a/lib/src/main/java/de/edux/data/provider/IDataProvider.java b/lib/src/main/java/de/edux/data/provider/IDataProvider.java
deleted file mode 100644
index 767a76b..0000000
--- a/lib/src/main/java/de/edux/data/provider/IDataProvider.java
+++ /dev/null
@@ -1,24 +0,0 @@
-package de.edux.data.provider;
-
-import java.util.List;
-
-public interface IDataProvider<T> {
-    List<T> getTrainData();
-
-    List<T> getTestData();
-
-    void printStatistics();
-
-    T getRandom(boolean equalDistribution);
-
-    double[][] getTrainFeatures();
-
-    double[][] getTrainLabels();
-
-    double[][] getTestFeatures();
-
-    double[][] getTestLabels();
-
-    String getDescription();
-
-}
diff --git a/lib/src/main/java/de/edux/data/provider/IDataUtil.java b/lib/src/main/java/de/edux/data/provider/IDataUtil.java
deleted file mode 100644
index 1f85516..0000000
--- a/lib/src/main/java/de/edux/data/provider/IDataUtil.java
+++ /dev/null
@@ -1,17 +0,0 @@
-package de.edux.data.provider;
-
-import de.edux.ml.nn.network.api.Dataset;
-
-import java.io.File;
-import java.util.List;
-
-public interface IDataUtil<T> {
-    List<T> loadDataSetFromCSV(File csvFile, char csvSeparator, boolean normalize, boolean shuffle, boolean filterIncompleteRecords);
-
-    Dataset<T> split(List<T> dataset, double trainTestSplitRatio);
-
-    double[][] getInputs(List<T> dataset);
-
-    double[][] getTargets(List<T> dataset);
-
-}
diff --git a/lib/src/main/java/de/edux/data/provider/Normalizer.java b/lib/src/main/java/de/edux/data/provider/Normalizer.java
new file mode 100644
index 0000000..0f0b3be
--- /dev/null
+++ b/lib/src/main/java/de/edux/data/provider/Normalizer.java
@@ -0,0 +1,7 @@
+package de.edux.data.provider;
+
+import java.util.List;
+
+public interface Normalizer {
+    List<String[]> normalize(List<String[]> dataset);
+}
diff --git a/lib/src/main/java/de/edux/data/reader/CSVIDataReader.java b/lib/src/main/java/de/edux/data/reader/CSVIDataReader.java
index 6b14067..d6a010e 100644
--- a/lib/src/main/java/de/edux/data/reader/CSVIDataReader.java
+++ b/lib/src/main/java/de/edux/data/reader/CSVIDataReader.java
@@ -19,7 +19,6 @@ public List<String[]> readFile(File file, char separator) {
         try(CSVReader reader = new CSVReaderBuilder(
                 new FileReader(file))
                 .withCSVParser(customCSVParser)
-                .withSkipLines(1)
                 .build()){
             result = reader.readAll();
         } catch (CsvException | IOException e) {
diff --git a/lib/src/main/java/de/edux/functions/imputation/ImputationStrategy.java b/lib/src/main/java/de/edux/functions/imputation/ImputationStrategy.java
new file mode 100644
index 0000000..1c411c1
--- /dev/null
+++ b/lib/src/main/java/de/edux/functions/imputation/ImputationStrategy.java
@@ -0,0 +1,5 @@
+package de.edux.functions.imputation;
+
+public enum ImputationStrategy {
+    DUMMY, MEAN, AVERAGE, MODE;
+}
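The target encoding added above pairs a class map (label to index, in order of first appearance) with one-hot rows emitted by getTargets. A self-contained sketch of the same scheme; the class name OneHotSketch and the sample labels are illustrative only:

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class OneHotSketch {
    public static void main(String[] args) {
        List<String> labels = List.of("Setosa", "Versicolor", "Virginica", "Setosa");

        // Index classes in order of first appearance, as loadDataSetFromCSV does.
        Map<String, Integer> classMap = new HashMap<>();
        for (String label : labels) {
            classMap.putIfAbsent(label, classMap.size());
        }

        // One row per sample with a single 1.0 at the class index, as getTargets does.
        double[][] targets = new double[labels.size()][classMap.size()];
        for (int i = 0; i < labels.size(); i++) {
            targets[i][classMap.get(labels.get(i))] = 1.0;
        }

        System.out.println(classMap);                     // e.g. {Setosa=0, Versicolor=1, Virginica=2}
        System.out.println(Arrays.deepToString(targets)); // first and last rows are [1.0, 0.0, 0.0]
    }
}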
targetColumn); + + Map getClassMap(); + + double[][] getTrainFeatures(int[] inputColumns); + + double[][] getTrainLabels(int targetColumn); + + double[][] getTestLabels(int targetColumn); + + double[][] getTestFeatures(int[] inputColumns); + + +} diff --git a/lib/src/main/java/de/edux/data/provider/IDataProvider.java b/lib/src/main/java/de/edux/data/provider/IDataProvider.java deleted file mode 100644 index 767a76b..0000000 --- a/lib/src/main/java/de/edux/data/provider/IDataProvider.java +++ /dev/null @@ -1,24 +0,0 @@ -package de.edux.data.provider; - -import java.util.List; - -public interface IDataProvider { - List getTrainData(); - - List getTestData(); - - void printStatistics(); - - T getRandom(boolean equalDistribution); - - double[][] getTrainFeatures(); - - double[][] getTrainLabels(); - - double[][] getTestFeatures(); - - double[][] getTestLabels(); - - String getDescription(); - -} diff --git a/lib/src/main/java/de/edux/data/provider/IDataUtil.java b/lib/src/main/java/de/edux/data/provider/IDataUtil.java deleted file mode 100644 index 1f85516..0000000 --- a/lib/src/main/java/de/edux/data/provider/IDataUtil.java +++ /dev/null @@ -1,17 +0,0 @@ -package de.edux.data.provider; - -import de.edux.ml.nn.network.api.Dataset; - -import java.io.File; -import java.util.List; - -public interface IDataUtil { - List loadDataSetFromCSV(File csvFile, char csvSeparator, boolean normalize, boolean shuffle, boolean filterIncompleteRecords); - - Dataset split(List dataset, double trainTestSplitRatio); - - double[][] getInputs(List dataset); - - double[][] getTargets(List dataset); - -} diff --git a/lib/src/main/java/de/edux/data/provider/Normalizer.java b/lib/src/main/java/de/edux/data/provider/Normalizer.java new file mode 100644 index 0000000..0f0b3be --- /dev/null +++ b/lib/src/main/java/de/edux/data/provider/Normalizer.java @@ -0,0 +1,7 @@ +package de.edux.data.provider; + +import java.util.List; + +public interface Normalizer { + List normalize(List dataset); +} diff --git a/lib/src/main/java/de/edux/data/reader/CSVIDataReader.java b/lib/src/main/java/de/edux/data/reader/CSVIDataReader.java index 6b14067..d6a010e 100644 --- a/lib/src/main/java/de/edux/data/reader/CSVIDataReader.java +++ b/lib/src/main/java/de/edux/data/reader/CSVIDataReader.java @@ -19,7 +19,6 @@ public List readFile(File file, char separator) { try(CSVReader reader = new CSVReaderBuilder( new FileReader(file)) .withCSVParser(customCSVParser) - .withSkipLines(1) .build()){ result = reader.readAll(); } catch (CsvException | IOException e) { diff --git a/lib/src/main/java/de/edux/functions/imputation/ImputationStrategy.java b/lib/src/main/java/de/edux/functions/imputation/ImputationStrategy.java new file mode 100644 index 0000000..1c411c1 --- /dev/null +++ b/lib/src/main/java/de/edux/functions/imputation/ImputationStrategy.java @@ -0,0 +1,5 @@ +package de.edux.functions.imputation; + +public enum ImputationStrategy { + DUMMY, MEAN, AVERAGE, MODE; +} diff --git a/lib/src/test/java/de/edux/data/provider/DataProcessorTest.java b/lib/src/test/java/de/edux/data/provider/DataProcessorTest.java index a0b3af0..15a46d2 100644 --- a/lib/src/test/java/de/edux/data/provider/DataProcessorTest.java +++ b/lib/src/test/java/de/edux/data/provider/DataProcessorTest.java @@ -1,152 +1,159 @@ package de.edux.data.provider; -import de.edux.data.reader.CSVIDataReader; -import de.edux.ml.nn.network.api.Dataset; +import de.edux.data.reader.IDataReader; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import 
org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.InjectMocks; import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.mockito.junit.jupiter.MockitoExtension; import java.io.File; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; +import java.util.Map; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyChar; import static org.mockito.Mockito.when; - @ExtendWith(MockitoExtension.class) class DataProcessorTest { - @InjectMocks - private DataProcessor dataProcessor = getDummyDataUtil(); + + private static final boolean SKIP_HEAD = true; + private List dummyDataset; + + private DataProcessor dataProcessor; @Mock - private CSVIDataReader csvDataReader; + IDataReader dataReader; @BeforeEach void setUp() { - dataProcessor = getDummyDataUtil(); + dummyDataset = new ArrayList<>(); + dummyDataset.add(new String[]{"col1", "col2", "Name", "col4", "col5"}); + dummyDataset.add(new String[]{"1", "2", "3", "Anna", "5"}); + dummyDataset.add(new String[]{"6", "7", "8", "Nina", "10"}); + dummyDataset.add(new String[]{"11", "12", "13", "Johanna", "15"}); + dummyDataset.add(new String[]{"16", "17", "18", "Isabela", "20"}); + when(dataReader.readFile(any(), anyChar())).thenReturn(dummyDataset); + dataProcessor = new DataProcessor(dataReader); } @Test - void testSplitWithValidRatio() { - List dataset = Arrays.asList("A", "B", "C", "D", "E"); - double trainTestSplitRatio = 0.6; - - Dataset result = dataProcessor.split(dataset, trainTestSplitRatio); + void shouldSkipHead() { + dataProcessor.loadDataSetFromCSV(new File("mockpathhere"), ',', SKIP_HEAD, new int[]{0, 1, 2, 4}, 3); + assertEquals(4, dataProcessor.getDataset().size(), "Number of rows does not match."); + } - assertEquals(3, result.trainData().size(), "Train dataset size should be 3"); - assertEquals(2, result.testData().size(), "Test dataset size should be 2"); + @Test + void shouldNotSkipHead() { + dataProcessor.loadDataSetFromCSV(new File("mockpathhere"), ',', false, new int[]{0, 1, 2, 4}, 3); + assertEquals(5, dataProcessor.getDataset().size(), "Number of rows does not match."); } + @Test - void testSplitWithZeroRatio() { - List dataset = Arrays.asList("A", "B", "C", "D", "E"); - double trainTestSplitRatio = 0.0; + void getTargets() { + when(dataReader.readFile(any(), anyChar())).thenReturn(dummyDataset); + dataProcessor.loadDataSetFromCSV(new File("mockpathhere"), ',', SKIP_HEAD, new int[]{0, 1, 2, 4}, 3) + .split(0.5); + dummyDataset.add(new String[]{"21", "22", "23", "Isabela", "25"}); + + double[][] targets = dataProcessor.getTargets(dummyDataset, 3); + double[][] expectedTargets = { + {1.0, 0.0, 0.0, 0.0}, // Anna + {0.0, 1.0, 0.0, 0.0}, // Nina + {0.0, 0.0, 1.0, 0.0}, // Johanna + {0.0, 0.0, 0.0, 1.0}, // Isabela + {0.0, 0.0, 0.0, 1.0} // Isabela + }; + + for (int i = 0; i < expectedTargets.length; i++) { + assertArrayEquals(expectedTargets[i], targets[i], "Die Zielzeile " + i + " stimmt nicht überein."); + } - Dataset result = dataProcessor.split(dataset, trainTestSplitRatio); + Map classMap = dataProcessor.getClassMap(); + Map expectedClassMap = Map.of( + "Anna", 0, + "Nina", 1, + "Johanna", 2, + "Isabela", 3); - assertEquals(0, result.trainData().size(), "Train dataset size should be 0"); - assertEquals(5, result.testData().size(), "Test dataset size should be 5"); + 
+        assertEquals(expectedClassMap, classMap, "Class map does not match.");
     }
 
     @Test
-    void testSplitWithFullRatio() {
-        List<String> dataset = Arrays.asList("A", "B", "C", "D", "E");
-        double trainTestSplitRatio = 1.0;
+    void getInputs() {
+        dataProcessor.loadDataSetFromCSV(new File("mockpathhere"), ',', SKIP_HEAD, new int[]{0, 1, 2, 4}, 3)
+                .split(0.5);
+        double[][] inputs = dataProcessor.getInputs(dummyDataset, new int[]{0, 1, 2, 4});
+
+        double[][] expectedInputs = {
+                {1.0, 2.0, 3.0, 5.0},
+                {6.0, 7.0, 8.0, 10.0},
+                {11.0, 12.0, 13.0, 15.0},
+                {16.0, 17.0, 18.0, 20.0}
+        };
+
+        assertEquals(expectedInputs.length, inputs.length, "Row count does not match.");
 
-        Dataset<String> result = dataProcessor.split(dataset, trainTestSplitRatio);
+        for (int i = 0; i < expectedInputs.length; i++) {
+            assertArrayEquals(expectedInputs[i], inputs[i], "Row " + i + " does not match the expected values.");
+        }
+    }
 
-        assertEquals(5, result.trainData().size(), "Train dataset size should be 5");
-        assertEquals(0, result.testData().size(), "Test dataset size should be 0");
+    private List<String[]> duplicateList(List<String[]> list) {
+        List<String[]> duplicate = new ArrayList<>();
+        for (String[] row : list) {
+            duplicate.add(row.clone());
+        }
+        return duplicate;
     }
 
     @Test
-    void testSplitWithInvalidNegativeRatio() {
-        List<String> dataset = Arrays.asList("A", "B", "C", "D", "E");
-        double trainTestSplitRatio = -0.1;
+    void shouldNormalize() {
+        dataProcessor.loadDataSetFromCSV(new File("mockpathhere"), ',', SKIP_HEAD, new int[]{0, 1, 2, 4}, 3)
+                .split(0.5);
+        List<String[]> normalizedDataset = dataProcessor.normalize().getDataset();
+
+        String[][] expectedNormalizedValues = {
+                {"0.0", "0.0", "0.0", "Anna", "0.0"},
+                {"0.3333333333333333", "0.3333333333333333", "0.3333333333333333", "Nina", "0.3333333333333333"},
+                {"0.6666666666666666", "0.6666666666666666", "0.6666666666666666", "Johanna", "0.6666666666666666"},
+                {"1.0", "1.0", "1.0", "Isabela", "1.0"}
+        };
 
-        assertThrows(IllegalArgumentException.class, () -> dataProcessor.split(dataset, trainTestSplitRatio));
+        for (int i = 1; i < normalizedDataset.size(); i++) {
+            String[] row = normalizedDataset.get(i);
+            assertArrayEquals(expectedNormalizedValues[i], row, "Row " + i + " does not match the expected normalized values.");
+        }
     }
 
     @Test
-    void testSplitWithInvalidAboveOneRatio() {
-        List<String> dataset = Arrays.asList("A", "B", "C", "D", "E");
-        double trainTestSplitRatio = 1.1;
+    void shouldShuffle() {
+        List<String[]> originalDataset = duplicateList(dummyDataset);
+        dataProcessor.loadDataSetFromCSV(new File("mockpathhere"), ',', false, new int[]{0, 1, 2, 4}, 3)
+                .split(0.5);
+        List<String[]> shuffledDataset = dataProcessor.shuffle().getDataset();
 
-        assertThrows(IllegalArgumentException.class, () -> dataProcessor.split(dataset, trainTestSplitRatio));
+        assertNotEquals(originalDataset, shuffledDataset, "The row order did not change.");
     }
 
-    @Test
-    void testLoadTDataSetWithoutNormalizationAndShuffling() {
-        File dummyFile = new File("dummy.csv");
-        char separator = ',';
-        String[] csvFirstLine = {"A", "B", "C", "D", "E"};
-        String[] csvSecondLine = {"F", "G", "H", "I", "J"};
-        List<String[]> csvLine = new ArrayList<>();
-        csvLine.add(csvFirstLine);
-        csvLine.add(csvSecondLine);
-        when(csvDataReader.readFile(any(), anyChar())).thenReturn(csvLine);
+    @Test
+    void shouldReturnTrainTestDataset() {
+        dataProcessor.loadDataSetFromCSV(new File("mockpathhere"), ',', false, new int[]{0, 1, 2, 4}, 3);
+        dataProcessor.split(0.5);
-        List<String> result = dataProcessor.loadDataSetFromCSV(dummyFile, separator, false, false, false);
+        int[] inputColumns = new int[]{0, 1, 2, 4};
+        double[][] trainFeatures = dataProcessor.getTrainFeatures(inputColumns);
+        double[][] testFeatures = dataProcessor.getTestFeatures(inputColumns);
 
-        assertEquals(2, result.size(), "Dataset size should be 2");
-    }
+        double[][] trainLabels = dataProcessor.getTrainLabels(3);
+        double[][] testLabels = dataProcessor.getTestLabels(3);
 
-    private DataProcessor<String> getDummyDataUtil() {
-        return new DataProcessor<>(csvDataReader) {
-
-            @Override
-            public void normalize(List<String> dataset) {
-            }
-
-            @Override
-            public String mapToDataRecord(String[] csvLine) {
-                return null;
-            }
-
-            @Override
-            public double[][] getInputs(List<String> dataset) {
-                return new double[0][];
-            }
-
-            @Override
-            public double[][] getTargets(List<String> dataset) {
-                return new double[0][];
-            }
-
-            @Override
-            public String getDatasetDescription() {
-                return null;
-            }
-
-            @Override
-            public double[][] getTrainFeatures() {
-                return new double[0][];
-            }
-
-            @Override
-            public double[][] getTrainLabels() {
-                return new double[0][];
-            }
-
-            @Override
-            public double[][] getTestLabels() {
-                return new double[0][];
-            }
-
-            @Override
-            public double[][] getTestFeatures() {
-                return new double[0][];
-            }
-        };
     }
-}
+}
\ No newline at end of file
diff --git a/lib/src/test/java/de/edux/data/provider/Penguin.java b/lib/src/test/java/de/edux/data/provider/Penguin.java
deleted file mode 100644
index 5228024..0000000
--- a/lib/src/test/java/de/edux/data/provider/Penguin.java
+++ /dev/null
@@ -1,10 +0,0 @@
-package de.edux.data.provider;
-
-/**
- * SeaBorn penguin dto
- */
-public record Penguin(String species, String island, double billLengthMm, double billDepthMm, int flipperLengthMm, int bodyMassG, String sex){
-    public double[] getFeatures() {
-        return new double[]{billLengthMm, billDepthMm, flipperLengthMm, bodyMassG};
-    }
-}
diff --git a/lib/src/test/java/de/edux/data/provider/SeabornDataProcessor.java b/lib/src/test/java/de/edux/data/provider/SeabornDataProcessor.java
deleted file mode 100644
index bc08ee8..0000000
--- a/lib/src/test/java/de/edux/data/provider/SeabornDataProcessor.java
+++ /dev/null
@@ -1,182 +0,0 @@
-package de.edux.data.provider;
-
-import java.util.ArrayList;
-import java.util.List;
-
-public class SeabornDataProcessor extends DataProcessor<Penguin> {
-    @Override
-    public void normalize(List<Penguin> penguins) {
-        double maxBillLength = penguins.stream().mapToDouble(Penguin::billLengthMm).max().orElse(1);
-        double minBillLength = penguins.stream().mapToDouble(Penguin::billLengthMm).min().orElse(0);
-
-        double maxBillDepth = penguins.stream().mapToDouble(Penguin::billDepthMm).max().orElse(1);
-        double minBillDepth = penguins.stream().mapToDouble(Penguin::billDepthMm).min().orElse(0);
-
-        double maxFlipperLength = penguins.stream().mapToInt(Penguin::flipperLengthMm).max().orElse(1);
-        double minFlipperLength = penguins.stream().mapToInt(Penguin::flipperLengthMm).min().orElse(0);
-
-        double maxBodyMass = penguins.stream().mapToInt(Penguin::bodyMassG).max().orElse(1);
-        double minBodyMass = penguins.stream().mapToInt(Penguin::bodyMassG).min().orElse(0);
-
-        List<Penguin> normalizedPenguins = new ArrayList<>();
-        for (Penguin p : penguins) {
-            double normalizedBillLength = (p.billLengthMm() - minBillLength) / (maxBillLength - minBillLength);
-            double normalizedBillDepth = (p.billDepthMm() - minBillDepth) / (maxBillDepth - minBillDepth);
-            double normalizedFlipperLength = (p.flipperLengthMm() - minFlipperLength) / (maxFlipperLength - minFlipperLength);
-            double normalizedBodyMass = (p.bodyMassG() - minBodyMass) / (maxBodyMass - minBodyMass);
-
-            p = new Penguin(p.species(), p.island(), normalizedBillLength, normalizedBillDepth, (int) normalizedFlipperLength, (int) normalizedBodyMass, p.sex());
-
-            Penguin normalizedPenguin = new Penguin(p.species(), p.island(), normalizedBillLength, normalizedBillDepth, (int) normalizedFlipperLength, (int) normalizedBodyMass, p.sex());
-            normalizedPenguins.add(normalizedPenguin);
-        }
-        penguins.clear();
-        penguins.addAll(normalizedPenguins);
-    }
-
-    @Override
-    public Penguin mapToDataRecord(String[] csvLine) {
-        if (csvLine.length != 7) {
-            throw new IllegalArgumentException("CSV line format is invalid. Expected 7 fields, got " + csvLine.length + ".");
-        }
-
-        for (String value : csvLine) {
-            if (value == null || value.trim().isEmpty()) {
-                return null;
-            }
-        }
-
-        String species = csvLine[0];
-        if (!(species.equalsIgnoreCase("adelie") || species.equalsIgnoreCase("chinstrap") || species.equalsIgnoreCase("gentoo"))) {
-            throw new IllegalArgumentException("Invalid species: " + species);
-        }
-
-        String island = csvLine[1];
-
-        double billLengthMm;
-        double billDepthMm;
-        int flipperLengthMm;
-        int bodyMassG;
-
-        try {
-            billLengthMm = Double.parseDouble(csvLine[2]);
-            if (billLengthMm < 0) {
-                throw new IllegalArgumentException("Bill length cannot be negative.");
-            }
-
-            billDepthMm = Double.parseDouble(csvLine[3]);
-            if (billDepthMm < 0) {
-                throw new IllegalArgumentException("Bill depth cannot be negative.");
-            }
-
-            flipperLengthMm = Integer.parseInt(csvLine[4]);
-            if (flipperLengthMm < 0) {
-                throw new IllegalArgumentException("Flipper length cannot be negative.");
-            }
-
-            bodyMassG = Integer.parseInt(csvLine[5]);
-            if (bodyMassG < 0) {
-                throw new IllegalArgumentException("Body mass cannot be negative.");
-            }
-        } catch (NumberFormatException e) {
-            throw new IllegalArgumentException("Invalid number format in CSV line", e);
-        }
-
-        String sex = csvLine[6];
-        if (!(sex.equalsIgnoreCase("male") || sex.equalsIgnoreCase("female"))) {
-            throw new IllegalArgumentException("Invalid sex: " + sex);
-        }
-
-        return new Penguin(species, island, billLengthMm, billDepthMm, flipperLengthMm, bodyMassG, sex);
-    }
-
-    @Override
-    public double[][] getInputs(List<Penguin> dataset) {
-        double[][] inputs = new double[dataset.size()][4];
-
-        for (int i = 0; i < dataset.size(); i++) {
-            Penguin p = dataset.get(i);
-            inputs[i][0] = p.billLengthMm();
-            inputs[i][1] = p.billDepthMm();
-            inputs[i][2] = p.flipperLengthMm();
-            inputs[i][3] = p.bodyMassG();
-        }
-
-        return inputs;
-    }
-
-    @Override
-    public double[][] getTargets(List<Penguin> dataset) {
-        double[][] targets = new double[dataset.size()][1];
-
-        for (int i = 0; i < dataset.size(); i++) {
-            Penguin p = dataset.get(i);
-            targets[i][0] = "Male".equalsIgnoreCase(p.sex()) ? 1.0 : 0.0;
-        }
-
-        return targets;
-    }
-
-    @Override
-    public String getDatasetDescription() {
-        return "Seaborn penguins dataset";
-    }
-
-    @Override
-    public double[][] getTrainFeatures() {
-        return featuresOf(getSplitedDataset().trainData());
-    }
-
-    @Override
-    public double[][] getTrainLabels() {
-        return labelsOf(getSplitedDataset().trainData());
-    }
-
-    @Override
-    public double[][] getTestFeatures() {
-        return featuresOf(getSplitedDataset().testData());
-    }
-
-    @Override
-    public double[][] getTestLabels() {
-        return labelsOf(getSplitedDataset().testData());
-    }
-
-    private double[][] featuresOf(List<Penguin> data) {
-        double[][] features = new double[data.size()][4];
-
-        for (int i = 0; i < data.size(); i++) {
-            Penguin p = data.get(i);
-            features[i][0] = p.billLengthMm();
-            features[i][1] = p.billDepthMm();
-            features[i][2] = p.flipperLengthMm();
-            features[i][3] = p.bodyMassG();
-        }
-
-        return features;
-    }
-
-    private double[][] labelsOf(List<Penguin> data) {
-        double[][] labels = new double[data.size()][3];
-
-        for (int i = 0; i < data.size(); i++) {
-            Penguin p = data.get(i);
-            switch (p.species().toLowerCase()) {
-                case "adelie":
-                    labels[i] = new double[]{1.0, 0.0, 0.0};
-                    break;
-                case "chinstrap":
-                    labels[i] = new double[]{0.0, 1.0, 0.0};
-                    break;
-                case "gentoo":
-                    labels[i] = new double[]{0.0, 0.0, 1.0};
-                    break;
-                default:
-                    throw new IllegalArgumentException("Unbekannte Pinguinart: " + p.species());
-            }
-        }
-
-        return labels;
-    }
-
-}
\ No newline at end of file
diff --git a/lib/src/test/java/de/edux/data/provider/SeabornProvider.java b/lib/src/test/java/de/edux/data/provider/SeabornProvider.java
deleted file mode 100644
index de98c57..0000000
--- a/lib/src/test/java/de/edux/data/provider/SeabornProvider.java
+++ /dev/null
@@ -1,109 +0,0 @@
-package de.edux.data.provider;
-
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.util.List;
-
-public class SeabornProvider implements IDataProvider<Penguin> {
-    private static final Logger LOG = LoggerFactory.getLogger(SeabornProvider.class);
-    private final List<Penguin> dataset;
-    private final List<Penguin> trainingData;
-    private final List<Penguin> testData;
-
-    public SeabornProvider(List<Penguin> dataset, List<Penguin> trainingData, List<Penguin> testData) {
-        this.dataset = dataset;
-        this.trainingData = trainingData;
-        this.testData = testData;
-    }
-
-    @Override
-    public List<Penguin> getTrainData() {
-        return trainingData;
-    }
-
-    @Override
-    public List<Penguin> getTestData() {
-        return testData;
-    }
-
-    @Override
-    public void printStatistics() {
-        LOG.info("========================= Data Statistic ==================");
-        LOG.info("Dataset: Seaborn Penguins");
-        LOG.info("Description: " + getDescription());
-        LOG.info("Total dataset size: " + dataset.size());
-        LOG.info("Training dataset size: " + trainingData.size());
-        LOG.info("Test data set size: " + testData.size());
-        LOG.info("Classes: " + getTrainLabels()[0].length);
-        LOG.info("===========================================================");
-    }
-
-    @Override
-    public Penguin getRandom(boolean equalDistribution) {
-        return dataset.get((int) (Math.random() * dataset.size()));
-    }
-
-    @Override
-    public double[][] getTrainFeatures() {
-        return featuresOf(trainingData);
-    }
-
-    private double[][] featuresOf(List<Penguin> data) {
-        double[][] features = new double[data.size()][4];
-
-        for (int i = 0; i < data.size(); i++) {
-            Penguin p = data.get(i);
-            features[i][0] = p.billLengthMm();
-            features[i][1] = p.billDepthMm();
-            features[i][2] = p.flipperLengthMm();
-            features[i][3] = p.bodyMassG();
-        }
-
-        return features;
-    }
-
-
-    @Override
-    public double[][] getTrainLabels() {
-        return labelsOf(trainingData);
-    }
-
-    private double[][] labelsOf(List<Penguin> data) {
-        double[][] labels = new double[data.size()][3];
-
-        for (int i = 0; i < data.size(); i++) {
-            Penguin p = data.get(i);
-            switch (p.species().toLowerCase()) {
-                case "adelie":
-                    labels[i] = new double[]{1.0, 0.0, 0.0};
-                    break;
-                case "chinstrap":
-                    labels[i] = new double[]{0.0, 1.0, 0.0};
-                    break;
-                case "gentoo":
-                    labels[i] = new double[]{0.0, 0.0, 1.0};
-                    break;
-                default:
-                    throw new IllegalArgumentException("Unbekannte Pinguinart: " + p.species());
-            }
-        }
-
-        return labels;
-    }
-
-    @Override
-    public double[][] getTestFeatures() {
-        return featuresOf(testData);
-    }
-
-    @Override
-    public double[][] getTestLabels() {
-        return labelsOf(testData);
-    }
-
-    @Override
-    public String getDescription() {
-        return "The Seaborn Penguin dataset comprises measurements and species classifications for penguins collected from three islands in the Palmer Archipelago, Antarctica.";
-    }
-}
diff --git a/lib/src/test/java/de/edux/ml/RandomForestTest.java b/lib/src/test/java/de/edux/ml/RandomForestTest.java
deleted file mode 100644
index a2a6c4c..0000000
--- a/lib/src/test/java/de/edux/ml/RandomForestTest.java
+++ /dev/null
@@ -1,62 +0,0 @@
-package de.edux.ml;
-
-import de.edux.api.Classifier;
-import de.edux.data.provider.Penguin;
-import de.edux.data.provider.SeabornDataProcessor;
-import de.edux.data.provider.SeabornProvider;
-import de.edux.ml.randomforest.RandomForest;
-import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.Test;
-
-import java.io.File;
-import java.net.URL;
-import java.util.List;
-
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-class RandomForestTest {
-    private static final boolean SHUFFLE = true;
-    private static final boolean NORMALIZE = true;
-    private static final boolean FILTER_INCOMPLETE_RECORDS = true;
-    private static final double TRAIN_TEST_SPLIT_RATIO = 0.7;
-    private static final String CSV_FILE_PATH = "testdatasets/seaborn-penguins/penguins.csv";
-    private static SeabornProvider seabornProvider;
-    @BeforeAll
-    static void setup() {
-        URL url = RandomForestTest.class.getClassLoader().getResource(CSV_FILE_PATH);
-        if (url == null) {
-            throw new IllegalStateException("Cannot find file: " + CSV_FILE_PATH);
-        }
-        File csvFile = new File(url.getPath());
-        var seabornDataProcessor = new SeabornDataProcessor();
-        var dataset = seabornDataProcessor.loadDataSetFromCSV(csvFile, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS);
-        var splitedDataset = seabornDataProcessor.split(dataset, TRAIN_TEST_SPLIT_RATIO);
-        seabornProvider = new SeabornProvider(dataset, splitedDataset.trainData(), splitedDataset.testData());
-    }
-    @Test
-    void train() {
-        double[][] features = seabornProvider.getTrainFeatures();
-        double[][] labels = seabornProvider.getTrainLabels();
-
-        double[][] testFeatures = seabornProvider.getTestFeatures();
-        double[][] testLabels = seabornProvider.getTestLabels();
-
-        assertTrue(features.length > 0);
-        assertTrue(labels.length > 0);
-        assertTrue(testFeatures.length > 0);
-        assertTrue(testLabels.length > 0);
-
-        int numberOfTrees = 100;
-        int maxDepth = 8;
-        int minSampleSize = 2;
-        int minSamplesLeaf = 1;
-        int maxLeafNodes = 2;
-        int numFeatures = (int) Math.sqrt(features.length) * 3;
-
-        Classifier randomForest = new RandomForest(numberOfTrees, maxDepth, minSampleSize, minSamplesLeaf, maxLeafNodes, numFeatures);
-        randomForest.train(features, labels);
-        double accuracy = randomForest.evaluate(testFeatures, testLabels);
-        System.out.println(accuracy);
-        assertTrue(accuracy > 0.7);
-    }
-}
\ No newline at end of file
diff --git a/lib/src/test/java/de/edux/ml/decisiontree/DecisionTreeTest.java b/lib/src/test/java/de/edux/ml/decisiontree/DecisionTreeTest.java
deleted file mode 100644
index 537f19c..0000000
--- a/lib/src/test/java/de/edux/ml/decisiontree/DecisionTreeTest.java
+++ /dev/null
@@ -1,71 +0,0 @@
-package de.edux.ml.decisiontree;
-
-import de.edux.data.provider.Penguin;
-import de.edux.data.provider.SeabornDataProcessor;
-import de.edux.data.provider.SeabornProvider;
-import de.edux.ml.nn.network.api.Dataset;
-import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.RepeatedTest;
-import org.junit.jupiter.api.Test;
-
-import java.io.File;
-import java.net.URL;
-import java.util.List;
-
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-class DecisionTreeTest {
-    private static final boolean SHUFFLE = true;
-    private static final boolean NORMALIZE = true;
-    private static final boolean FILTER_INCOMPLETE_RECORDS = true;
-    private static final double TRAIN_TEST_SPLIT_RATIO = 0.7;
-    private static final String CSV_FILE_PATH = "testdatasets/seaborn-penguins/penguins.csv";
-    private static SeabornProvider seabornProvider;
-    @BeforeAll
-    static void setup() {
-        URL url = DecisionTreeTest.class.getClassLoader().getResource(CSV_FILE_PATH);
-        if (url == null) {
-            throw new IllegalStateException("Cannot find file: " + CSV_FILE_PATH);
-        }
-        File csvFile = new File(url.getPath());
-        var seabornDataProcessor = new SeabornDataProcessor();
-        var dataset = seabornDataProcessor.loadDataSetFromCSV(csvFile, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS);
-        Dataset<Penguin> splitedDataset = seabornDataProcessor.split(dataset, TRAIN_TEST_SPLIT_RATIO);
-        seabornProvider = new SeabornProvider(dataset, splitedDataset.trainData(), splitedDataset.testData());
-    }
-
-    @RepeatedTest(5)
-    void train() {
-        double[][] features = seabornProvider.getTrainFeatures();
-        double[][] labels = seabornProvider.getTrainLabels();
-
-        double[][] testFeatures = seabornProvider.getTestFeatures();
-        double[][] testLabels = seabornProvider.getTestLabels();
-
-        assertTrue(features.length > 0);
-        assertTrue(labels.length > 0);
-        assertTrue(testFeatures.length > 0);
-        assertTrue(testLabels.length > 0);
-        int maxDepth = 12;
-        int minSampleSplit = 2;
-        int minSampleLeaf = 1;
-        int maxLeafNodes = 8;
-        DecisionTree decisionTree = new DecisionTree(maxDepth, minSampleSplit, minSampleLeaf, maxLeafNodes);
-
-        decisionTree.train(features, labels);
-        double accuracy = decisionTree.evaluate(testFeatures, testLabels);
-        assertTrue(accuracy > 0.7);
-    }
-
-    @Test
-    void predict() {
-    }
-
-    @Test
-    void evaluate() {
-    }
-
-    @Test
-    void getFeatureImportance() {
-    }
-}
\ No newline at end of file
diff --git a/lib/src/test/java/de/edux/ml/nn/network/MultilayerPerceptronTest.java b/lib/src/test/java/de/edux/ml/nn/network/MultilayerPerceptronTest.java
deleted file mode 100644
index b4d4763..0000000
--- a/lib/src/test/java/de/edux/ml/nn/network/MultilayerPerceptronTest.java
+++ /dev/null
@@ -1,61 +0,0 @@
-package de.edux.ml.nn.network;
-
-import de.edux.data.provider.SeabornDataProcessor;
-import de.edux.data.provider.SeabornProvider;
-import de.edux.functions.activation.ActivationFunction;
-import de.edux.functions.initialization.Initialization;
-import de.edux.functions.loss.LossFunction;
-import de.edux.ml.nn.config.NetworkConfiguration;
-import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.RepeatedTest;
-
-import java.io.File;
-import java.net.URL;
-import java.util.List;
-
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-class MultilayerPerceptronTest {
-    private static final boolean SHUFFLE = true;
-    private static final boolean NORMALIZE = true;
-    private static final boolean FILTER_INCOMPLETE_RECORDS = true;
-    private static final double TRAIN_TEST_SPLIT_RATIO = 0.7;
-    private static final String CSV_FILE_PATH = "testdatasets/seaborn-penguins/penguins.csv";
-    private SeabornProvider seabornProvider;
-
-    @BeforeEach
-    void setUp() {
-        URL url = MultilayerPerceptronTest.class.getClassLoader().getResource(CSV_FILE_PATH);
-        if (url == null) {
-            throw new IllegalStateException("Cannot find file: " + CSV_FILE_PATH);
-        }
-        File csvFile = new File(url.getPath());
-        var seabornDataProcessor = new SeabornDataProcessor();
-        var dataset = seabornDataProcessor.loadDataSetFromCSV(csvFile, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS);
-        var trainTestSplittedList = seabornDataProcessor.split(dataset, TRAIN_TEST_SPLIT_RATIO);
-        seabornProvider = new SeabornProvider(dataset, trainTestSplittedList.trainData(), trainTestSplittedList.testData());
-
-    }
-
-    @RepeatedTest(3)
-    void shouldReachModelAccuracyAtLeast70() {
-        double[][] features = seabornProvider.getTrainFeatures();
-        double[][] labels = seabornProvider.getTrainLabels();
-
-        double[][] testFeatures = seabornProvider.getTestFeatures();
-        double[][] testLabels = seabornProvider.getTestLabels();
-
-        assertTrue(features.length > 0);
-        assertTrue(labels.length > 0);
-        assertTrue(testFeatures.length > 0);
-        assertTrue(testLabels.length > 0);
-
-        NetworkConfiguration networkConfiguration = new NetworkConfiguration(features[0].length, List.of(128, 256, 512), 3, 0.01, 300, ActivationFunction.LEAKY_RELU, ActivationFunction.SOFTMAX, LossFunction.CATEGORICAL_CROSS_ENTROPY, Initialization.XAVIER, Initialization.XAVIER);
-
-        MultilayerPerceptron multilayerPerceptron = new MultilayerPerceptron(networkConfiguration, testFeatures, testLabels);
-        multilayerPerceptron.train(features, labels);
-        double accuracy = multilayerPerceptron.evaluate(testFeatures, testLabels);
-        assertTrue(accuracy > 0.7);
-
-    }
-}
\ No newline at end of file