Skip to content

Commit

Permalink
Merge pull request #27 from Samyssmile/test/multilayerPerceptronTest
Browse files Browse the repository at this point in the history
test(26): Write jUnit Tests for Multilayer NeuralNetwork #26
  • Loading branch information
Samyssmile authored Oct 6, 2023
2 parents 892ed25 + 85927ca commit 8e2fa88
Show file tree
Hide file tree
Showing 9 changed files with 654 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ public static void main(String[] args) {

MultilayerPerceptron multilayerPerceptron = new MultilayerPerceptron(features, labels, testFeatures, testLabels, networkConfiguration);
multilayerPerceptron.train();
multilayerPerceptron.evaluate(testFeatures, testLabels);
}
}

3 changes: 1 addition & 2 deletions lib/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,10 @@ dependencies {
implementation 'org.apache.logging.log4j:log4j-slf4j-impl:2.20.0'
api 'org.apache.commons:commons-math3:3.6.1'

//test
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.10.0'
testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.10.0'
testImplementation 'org.mockito:mockito-core:5.5.0'
testImplementation 'org.mockito:mockito-junit-jupiter:5.5.0'
testImplementation 'org.junit.jupiter:junit-jupiter:5.8.1'
}

testing {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public class MultilayerPerceptron {
private final double[][] testTargets;
private final List<Neuron[]> hiddenLayers;
private final Neuron[] outputLayer;
private double bestAccuracy;

public MultilayerPerceptron(double[][] inputs, double[][] targets, double[][] testInputs, double[][] testTargets, NetworkConfiguration config) {
this.inputs = inputs;
Expand Down Expand Up @@ -73,7 +74,7 @@ private double[] feedforward(double[] input) {
}

public void train() {
double bestAccuracy = 0;
bestAccuracy = 0;
for (int epoch = 0; epoch < config.epochs(); epoch++) {
for (int i = 0; i < inputs.length; i++) {
double[] output = feedforward(inputs[i]);
Expand Down Expand Up @@ -140,7 +141,7 @@ private void updateWeights(int i, double[] output_error_signal, List<double[]> h
}
}

public double evaluate(double[][] testInputs, double[][] testTargets) {
private double evaluate(double[][] testInputs, double[][] testTargets) {
int correctCount = 0;

for (int i = 0; i < testInputs.length; i++) {
Expand All @@ -165,4 +166,8 @@ public double evaluate(double[][] testInputs, double[][] testTargets) {
public double[] predict(double[] input) {
return feedforward(input);
}

public double getAccuracy() {
return bestAccuracy;
}
}
10 changes: 10 additions & 0 deletions lib/src/test/java/de/edux/data/provider/Penguin.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package de.edux.data.provider;

/**
* SeaBorn penguin dto
*/
public record Penguin(String species, String island, double billLengthMm, double billDepthMm, int flipperLengthMm, int bodyMassG, String sex){
public double[] getFeatures() {
return new double[]{billLengthMm, billDepthMm, flipperLengthMm, bodyMassG};
}
}
119 changes: 119 additions & 0 deletions lib/src/test/java/de/edux/data/provider/SeabornDataProcessor.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package de.edux.data.provider;

import java.util.ArrayList;
import java.util.List;

public class SeabornDataProcessor extends DataUtil<Penguin> {
@Override
public void normalize(List<Penguin> penguins) {
double maxBillLength = penguins.stream().mapToDouble(Penguin::billLengthMm).max().orElse(1);
double minBillLength = penguins.stream().mapToDouble(Penguin::billLengthMm).min().orElse(0);

double maxBillDepth = penguins.stream().mapToDouble(Penguin::billDepthMm).max().orElse(1);
double minBillDepth = penguins.stream().mapToDouble(Penguin::billDepthMm).min().orElse(0);

double maxFlipperLength = penguins.stream().mapToInt(Penguin::flipperLengthMm).max().orElse(1);
double minFlipperLength = penguins.stream().mapToInt(Penguin::flipperLengthMm).min().orElse(0);

double maxBodyMass = penguins.stream().mapToInt(Penguin::bodyMassG).max().orElse(1);
double minBodyMass = penguins.stream().mapToInt(Penguin::bodyMassG).min().orElse(0);

List<Penguin> normalizedPenguins = new ArrayList<>();
for (Penguin p : penguins) {
double normalizedBillLength = (p.billLengthMm() - minBillLength) / (maxBillLength - minBillLength);
double normalizedBillDepth = (p.billDepthMm() - minBillDepth) / (maxBillDepth - minBillDepth);
double normalizedFlipperLength = (p.flipperLengthMm() - minFlipperLength) / (maxFlipperLength - minFlipperLength);
double normalizedBodyMass = (p.bodyMassG() - minBodyMass) / (maxBodyMass - minBodyMass);

p = new Penguin(p.species(), p.island(), normalizedBillLength, normalizedBillDepth, (int) normalizedFlipperLength, (int) normalizedBodyMass, p.sex());

Penguin normalizedPenguin = new Penguin(p.species(), p.island(), normalizedBillLength, normalizedBillDepth, (int) normalizedFlipperLength, (int) normalizedBodyMass, p.sex());
normalizedPenguins.add(normalizedPenguin);
}
penguins.clear();
penguins.addAll(normalizedPenguins);
}

@Override
public Penguin mapToDataRecord(String[] csvLine) {
if (csvLine.length != 7) {
throw new IllegalArgumentException("CSV line format is invalid. Expected 7 fields, got " + csvLine.length + ".");
}

for (String value : csvLine) {
if (value == null || value.trim().isEmpty()) {
return null;
}
}

String species = csvLine[0];
if (!(species.equalsIgnoreCase("adelie") || species.equalsIgnoreCase("chinstrap") || species.equalsIgnoreCase("gentoo"))) {
throw new IllegalArgumentException("Invalid species: " + species);
}

String island = csvLine[1];

double billLengthMm;
double billDepthMm;
int flipperLengthMm;
int bodyMassG;

try {
billLengthMm = Double.parseDouble(csvLine[2]);
if (billLengthMm < 0) {
throw new IllegalArgumentException("Bill length cannot be negative.");
}

billDepthMm = Double.parseDouble(csvLine[3]);
if (billDepthMm < 0) {
throw new IllegalArgumentException("Bill depth cannot be negative.");
}

flipperLengthMm = Integer.parseInt(csvLine[4]);
if (flipperLengthMm < 0) {
throw new IllegalArgumentException("Flipper length cannot be negative.");
}

bodyMassG = Integer.parseInt(csvLine[5]);
if (bodyMassG < 0) {
throw new IllegalArgumentException("Body mass cannot be negative.");
}
} catch (NumberFormatException e) {
throw new IllegalArgumentException("Invalid number format in CSV line", e);
}

String sex = csvLine[6];
if (!(sex.equalsIgnoreCase("male") || sex.equalsIgnoreCase("female"))) {
throw new IllegalArgumentException("Invalid sex: " + sex);
}

return new Penguin(species, island, billLengthMm, billDepthMm, flipperLengthMm, bodyMassG, sex);
}

@Override
public double[][] getInputs(List<Penguin> dataset) {
double[][] inputs = new double[dataset.size()][4];

for (int i = 0; i < dataset.size(); i++) {
Penguin p = dataset.get(i);
inputs[i][0] = p.billLengthMm();
inputs[i][1] = p.billDepthMm();
inputs[i][2] = p.flipperLengthMm();
inputs[i][3] = p.bodyMassG();
}

return inputs;
}

@Override
public double[][] getTargets(List<Penguin> dataset) {
double[][] targets = new double[dataset.size()][1];

for (int i = 0; i < dataset.size(); i++) {
Penguin p = dataset.get(i);
targets[i][0] = "Male".equalsIgnoreCase(p.sex()) ? 1.0 : 0.0;
}

return targets;
}
}
109 changes: 109 additions & 0 deletions lib/src/test/java/de/edux/data/provider/SeabornProvider.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package de.edux.data.provider;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;

public class SeabornProvider implements IDataProvider<Penguin> {
private static final Logger LOG = LoggerFactory.getLogger(SeabornProvider.class);
private final List<Penguin> dataset;
private final List<Penguin> trainingData;
private final List<Penguin> testData;

public SeabornProvider(List<Penguin> dataset, List<Penguin> trainingData, List<Penguin> testData) {
this.dataset = dataset;
this.trainingData = trainingData;
this.testData = testData;
}

@Override
public List<Penguin> getTrainData() {
return trainingData;
}

@Override
public List<Penguin> getTestData() {
return testData;
}

@Override
public void printStatistics() {
LOG.info("========================= Data Statistic ==================");
LOG.info("Dataset: Seaborn Penguins");
LOG.info("Description: " + getDescription());
LOG.info("Total dataset size: " + dataset.size());
LOG.info("Training dataset size: " + trainingData.size());
LOG.info("Test data set size: " + testData.size());
LOG.info("Classes: " + getTrainLabels()[0].length);
LOG.info("===========================================================");
}

@Override
public Penguin getRandom(boolean equalDistribution) {
return dataset.get((int) (Math.random() * dataset.size()));
}

@Override
public double[][] getTrainFeatures() {
return featuresOf(trainingData);
}

private double[][] featuresOf(List<Penguin> data) {
double[][] features = new double[data.size()][4]; // 4 numerische Eigenschaften

for (int i = 0; i < data.size(); i++) {
Penguin p = data.get(i);
features[i][0] = p.billLengthMm();
features[i][1] = p.billDepthMm();
features[i][2] = p.flipperLengthMm();
features[i][3] = p.bodyMassG();
}

return features;
}


@Override
public double[][] getTrainLabels() {
return labelsOf(trainingData);
}

private double[][] labelsOf(List<Penguin> data) {
double[][] labels = new double[data.size()][3]; // 3 Pinguinarten

for (int i = 0; i < data.size(); i++) {
Penguin p = data.get(i);
switch (p.species().toLowerCase()) {
case "adelie":
labels[i] = new double[]{1.0, 0.0, 0.0};
break;
case "chinstrap":
labels[i] = new double[]{0.0, 1.0, 0.0};
break;
case "gentoo":
labels[i] = new double[]{0.0, 0.0, 1.0};
break;
default:
throw new IllegalArgumentException("Unbekannte Pinguinart: " + p.species());
}
}

return labels;
}

@Override
public double[][] getTestFeatures() {
return featuresOf(testData);
}

@Override
public double[][] getTestLabels() {
return labelsOf(testData);
}

@Override
public String getDescription() {
return "The Seaborn Penguin dataset comprises measurements and species classifications for penguins collected from three islands in the Palmer Archipelago, Antarctica.";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package de.edux.ml.nn.network;

import de.edux.data.provider.Penguin;
import de.edux.data.provider.SeabornDataProcessor;
import de.edux.data.provider.SeabornProvider;
import de.edux.functions.activation.ActivationFunction;
import de.edux.functions.initialization.Initialization;
import de.edux.functions.loss.LossFunction;
import de.edux.ml.nn.config.NetworkConfiguration;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Test;

import java.io.File;
import java.net.URL;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertTrue;

class MultilayerPerceptronTest {
private static final boolean SHUFFLE = true;
private static final boolean NORMALIZE = true;
private static final boolean FILTER_INCOMPLETE_RECORDS = true;
private static final double TRAIN_TEST_SPLIT_RATIO = 0.7;
private static final String CSV_FILE_PATH = "testdatasets/seaborn-penguins/penguins.csv";
private SeabornProvider seabornProvider;

@BeforeEach
void setUp() {
URL url = MultilayerPerceptronTest.class.getClassLoader().getResource(CSV_FILE_PATH);
if (url == null) {
throw new IllegalStateException("Cannot find file: " + CSV_FILE_PATH);
}
File csvFile = new File(url.getPath());
var seabornDataProcessor = new SeabornDataProcessor();
var dataset = seabornDataProcessor.loadTDataSet(csvFile, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS);
List<List<Penguin>> trainTestSplittedList = seabornDataProcessor.split(dataset, TRAIN_TEST_SPLIT_RATIO);
seabornProvider = new SeabornProvider(dataset, trainTestSplittedList.get(0), trainTestSplittedList.get(1));

}

@RepeatedTest(3)
void shouldReachModelAccuracyAtLeast70() {
double[][] features = seabornProvider.getTrainFeatures();
double[][] labels = seabornProvider.getTrainLabels();

double[][] testFeatures = seabornProvider.getTestFeatures();
double[][] testLabels = seabornProvider.getTestLabels();

assertTrue(features.length > 0);
assertTrue(labels.length > 0);
assertTrue(testFeatures.length > 0);
assertTrue(testLabels.length > 0);

NetworkConfiguration networkConfiguration = new NetworkConfiguration(features[0].length, List.of(24, 6), 3, 0.001, 10000, ActivationFunction.LEAKY_RELU, ActivationFunction.SOFTMAX, LossFunction.CATEGORICAL_CROSS_ENTROPY, Initialization.XAVIER, Initialization.XAVIER);

MultilayerPerceptron multilayerPerceptron = new MultilayerPerceptron(features, labels, testFeatures, testLabels, networkConfiguration);
multilayerPerceptron.train();
double accuracy = multilayerPerceptron.getAccuracy();
assertTrue(accuracy > 0.7);

}
}
Loading

0 comments on commit 8e2fa88

Please sign in to comment.