Skip to content

Commit

Permalink
Merge pull request #66 from Samyssmile/drop-custom-data
Browse files Browse the repository at this point in the history
feat(#64): Drop Custom Data Preparation #64
  • Loading branch information
Samyssmile authored Oct 25, 2023
2 parents e29baac + 0140530 commit 44e1b8f
Show file tree
Hide file tree
Showing 33 changed files with 447 additions and 1,618 deletions.
36 changes: 20 additions & 16 deletions example/src/main/java/de/example/benchmark/Benchmark.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package de.example.benchmark;

import de.edux.api.Classifier;
import de.edux.data.provider.DataProcessor;
import de.edux.data.reader.CSVIDataReader;
import de.edux.functions.activation.ActivationFunction;
import de.edux.functions.initialization.Initialization;
import de.edux.functions.loss.LossFunction;
Expand All @@ -11,8 +13,6 @@
import de.edux.ml.randomforest.RandomForest;
import de.edux.ml.svm.SVMKernel;
import de.edux.ml.svm.SupportVectorMachine;
import de.example.data.seaborn.Penguin;
import de.example.data.seaborn.SeabornDataProcessor;

import java.io.File;
import java.util.ArrayList;
Expand All @@ -25,11 +25,10 @@
* Compare the performance of different classifiers
*/
public class Benchmark {
private static final boolean SHUFFLE = true;
private static final boolean NORMALIZE = true;
private static final boolean FILTER_INCOMPLETE_RECORDS = true;
private static final double TRAIN_TEST_SPLIT_RATIO = 0.75;
private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "seaborn-penguins" + File.separator + "penguins.csv");
private static final double TRAIN_TEST_SPLIT_RATIO = 0.70;
private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "iris" + File.separator + "iris.csv");
private static final boolean SKIP_HEAD = true;

private double[][] trainFeatures;
private double[][] trainLabels;
private double[][] testFeatures;
Expand All @@ -49,7 +48,7 @@ private void run() {
Classifier randomForest = new RandomForest(100, 10, 2, 1, 3, 60);
Classifier svm = new SupportVectorMachine(SVMKernel.LINEAR, 1);

networkConfiguration = new NetworkConfiguration(trainFeatures[0].length, List.of(128,256, 512), 3, 0.01, 300, ActivationFunction.LEAKY_RELU, ActivationFunction.SOFTMAX, LossFunction.CATEGORICAL_CROSS_ENTROPY, Initialization.XAVIER, Initialization.XAVIER);
networkConfiguration = new NetworkConfiguration(trainFeatures[0].length, List.of(128, 256, 512), 3, 0.01, 300, ActivationFunction.LEAKY_RELU, ActivationFunction.SOFTMAX, LossFunction.CATEGORICAL_CROSS_ENTROPY, Initialization.XAVIER, Initialization.XAVIER);
multilayerPerceptron = new MultilayerPerceptron(networkConfiguration, testFeatures, testLabels);
Map<String, Classifier> classifiers = Map.of(
"KNN", knn,
Expand All @@ -67,7 +66,7 @@ private void run() {
results.put("MLP", new ArrayList<>());


IntStream.range(0, 50).forEach(i -> {
IntStream.range(0, 1).forEach(i -> {
knn.train(trainFeatures, trainLabels);
decisionTree.train(trainFeatures, trainLabels);
randomForest.train(trainFeatures, trainLabels);
Expand Down Expand Up @@ -126,16 +125,21 @@ private void updateMLP(double[][] testFeatures, double[][] testLabels) {
}

private void initFeaturesAndLabels() {
var seabornDataProcessor = new SeabornDataProcessor();
List<Penguin> data = seabornDataProcessor.loadDataSetFromCSV(CSV_FILE, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS);
seabornDataProcessor.split(data, TRAIN_TEST_SPLIT_RATIO);
var featureColumnIndices = new int[]{0, 1, 2, 3};
var targetColumnIndex = 4;

var dataProcessor = new DataProcessor(new CSVIDataReader())
.loadDataSetFromCSV(CSV_FILE, ',', SKIP_HEAD, featureColumnIndices, targetColumnIndex)
.normalize()
.shuffle()
.split(TRAIN_TEST_SPLIT_RATIO);

trainFeatures = seabornDataProcessor.getTrainFeatures();
trainLabels = seabornDataProcessor.getTrainLabels();

testFeatures = seabornDataProcessor.getTestFeatures();
testLabels = seabornDataProcessor.getTestLabels();

trainFeatures = dataProcessor.getTrainFeatures(featureColumnIndices);
trainLabels = dataProcessor.getTrainLabels(targetColumnIndex);
testFeatures = dataProcessor.getTestFeatures(featureColumnIndices);
testLabels = dataProcessor.getTestLabels(targetColumnIndex);

}
}
33 changes: 0 additions & 33 deletions example/src/main/java/de/example/data/iris/Iris.java

This file was deleted.

133 changes: 0 additions & 133 deletions example/src/main/java/de/example/data/iris/IrisDataProcessor.java

This file was deleted.

93 changes: 0 additions & 93 deletions example/src/main/java/de/example/data/iris/IrisDataUtil.java

This file was deleted.

Loading

0 comments on commit 44e1b8f

Please sign in to comment.