-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #27 from Samyssmile/test/multilayerPerceptronTest
test(26): Write jUnit Tests for Multilayer NeuralNetwork #26
- Loading branch information
Showing
9 changed files
with
654 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
package de.edux.data.provider; | ||
|
||
/** | ||
* SeaBorn penguin dto | ||
*/ | ||
public record Penguin(String species, String island, double billLengthMm, double billDepthMm, int flipperLengthMm, int bodyMassG, String sex){ | ||
public double[] getFeatures() { | ||
return new double[]{billLengthMm, billDepthMm, flipperLengthMm, bodyMassG}; | ||
} | ||
} |
119 changes: 119 additions & 0 deletions
119
lib/src/test/java/de/edux/data/provider/SeabornDataProcessor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
package de.edux.data.provider; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
public class SeabornDataProcessor extends DataUtil<Penguin> { | ||
@Override | ||
public void normalize(List<Penguin> penguins) { | ||
double maxBillLength = penguins.stream().mapToDouble(Penguin::billLengthMm).max().orElse(1); | ||
double minBillLength = penguins.stream().mapToDouble(Penguin::billLengthMm).min().orElse(0); | ||
|
||
double maxBillDepth = penguins.stream().mapToDouble(Penguin::billDepthMm).max().orElse(1); | ||
double minBillDepth = penguins.stream().mapToDouble(Penguin::billDepthMm).min().orElse(0); | ||
|
||
double maxFlipperLength = penguins.stream().mapToInt(Penguin::flipperLengthMm).max().orElse(1); | ||
double minFlipperLength = penguins.stream().mapToInt(Penguin::flipperLengthMm).min().orElse(0); | ||
|
||
double maxBodyMass = penguins.stream().mapToInt(Penguin::bodyMassG).max().orElse(1); | ||
double minBodyMass = penguins.stream().mapToInt(Penguin::bodyMassG).min().orElse(0); | ||
|
||
List<Penguin> normalizedPenguins = new ArrayList<>(); | ||
for (Penguin p : penguins) { | ||
double normalizedBillLength = (p.billLengthMm() - minBillLength) / (maxBillLength - minBillLength); | ||
double normalizedBillDepth = (p.billDepthMm() - minBillDepth) / (maxBillDepth - minBillDepth); | ||
double normalizedFlipperLength = (p.flipperLengthMm() - minFlipperLength) / (maxFlipperLength - minFlipperLength); | ||
double normalizedBodyMass = (p.bodyMassG() - minBodyMass) / (maxBodyMass - minBodyMass); | ||
|
||
p = new Penguin(p.species(), p.island(), normalizedBillLength, normalizedBillDepth, (int) normalizedFlipperLength, (int) normalizedBodyMass, p.sex()); | ||
|
||
Penguin normalizedPenguin = new Penguin(p.species(), p.island(), normalizedBillLength, normalizedBillDepth, (int) normalizedFlipperLength, (int) normalizedBodyMass, p.sex()); | ||
normalizedPenguins.add(normalizedPenguin); | ||
} | ||
penguins.clear(); | ||
penguins.addAll(normalizedPenguins); | ||
} | ||
|
||
@Override | ||
public Penguin mapToDataRecord(String[] csvLine) { | ||
if (csvLine.length != 7) { | ||
throw new IllegalArgumentException("CSV line format is invalid. Expected 7 fields, got " + csvLine.length + "."); | ||
} | ||
|
||
for (String value : csvLine) { | ||
if (value == null || value.trim().isEmpty()) { | ||
return null; | ||
} | ||
} | ||
|
||
String species = csvLine[0]; | ||
if (!(species.equalsIgnoreCase("adelie") || species.equalsIgnoreCase("chinstrap") || species.equalsIgnoreCase("gentoo"))) { | ||
throw new IllegalArgumentException("Invalid species: " + species); | ||
} | ||
|
||
String island = csvLine[1]; | ||
|
||
double billLengthMm; | ||
double billDepthMm; | ||
int flipperLengthMm; | ||
int bodyMassG; | ||
|
||
try { | ||
billLengthMm = Double.parseDouble(csvLine[2]); | ||
if (billLengthMm < 0) { | ||
throw new IllegalArgumentException("Bill length cannot be negative."); | ||
} | ||
|
||
billDepthMm = Double.parseDouble(csvLine[3]); | ||
if (billDepthMm < 0) { | ||
throw new IllegalArgumentException("Bill depth cannot be negative."); | ||
} | ||
|
||
flipperLengthMm = Integer.parseInt(csvLine[4]); | ||
if (flipperLengthMm < 0) { | ||
throw new IllegalArgumentException("Flipper length cannot be negative."); | ||
} | ||
|
||
bodyMassG = Integer.parseInt(csvLine[5]); | ||
if (bodyMassG < 0) { | ||
throw new IllegalArgumentException("Body mass cannot be negative."); | ||
} | ||
} catch (NumberFormatException e) { | ||
throw new IllegalArgumentException("Invalid number format in CSV line", e); | ||
} | ||
|
||
String sex = csvLine[6]; | ||
if (!(sex.equalsIgnoreCase("male") || sex.equalsIgnoreCase("female"))) { | ||
throw new IllegalArgumentException("Invalid sex: " + sex); | ||
} | ||
|
||
return new Penguin(species, island, billLengthMm, billDepthMm, flipperLengthMm, bodyMassG, sex); | ||
} | ||
|
||
@Override | ||
public double[][] getInputs(List<Penguin> dataset) { | ||
double[][] inputs = new double[dataset.size()][4]; | ||
|
||
for (int i = 0; i < dataset.size(); i++) { | ||
Penguin p = dataset.get(i); | ||
inputs[i][0] = p.billLengthMm(); | ||
inputs[i][1] = p.billDepthMm(); | ||
inputs[i][2] = p.flipperLengthMm(); | ||
inputs[i][3] = p.bodyMassG(); | ||
} | ||
|
||
return inputs; | ||
} | ||
|
||
@Override | ||
public double[][] getTargets(List<Penguin> dataset) { | ||
double[][] targets = new double[dataset.size()][1]; | ||
|
||
for (int i = 0; i < dataset.size(); i++) { | ||
Penguin p = dataset.get(i); | ||
targets[i][0] = "Male".equalsIgnoreCase(p.sex()) ? 1.0 : 0.0; | ||
} | ||
|
||
return targets; | ||
} | ||
} |
109 changes: 109 additions & 0 deletions
109
lib/src/test/java/de/edux/data/provider/SeabornProvider.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
package de.edux.data.provider; | ||
|
||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
import java.util.List; | ||
|
||
public class SeabornProvider implements IDataProvider<Penguin> { | ||
private static final Logger LOG = LoggerFactory.getLogger(SeabornProvider.class); | ||
private final List<Penguin> dataset; | ||
private final List<Penguin> trainingData; | ||
private final List<Penguin> testData; | ||
|
||
public SeabornProvider(List<Penguin> dataset, List<Penguin> trainingData, List<Penguin> testData) { | ||
this.dataset = dataset; | ||
this.trainingData = trainingData; | ||
this.testData = testData; | ||
} | ||
|
||
@Override | ||
public List<Penguin> getTrainData() { | ||
return trainingData; | ||
} | ||
|
||
@Override | ||
public List<Penguin> getTestData() { | ||
return testData; | ||
} | ||
|
||
@Override | ||
public void printStatistics() { | ||
LOG.info("========================= Data Statistic =================="); | ||
LOG.info("Dataset: Seaborn Penguins"); | ||
LOG.info("Description: " + getDescription()); | ||
LOG.info("Total dataset size: " + dataset.size()); | ||
LOG.info("Training dataset size: " + trainingData.size()); | ||
LOG.info("Test data set size: " + testData.size()); | ||
LOG.info("Classes: " + getTrainLabels()[0].length); | ||
LOG.info("==========================================================="); | ||
} | ||
|
||
@Override | ||
public Penguin getRandom(boolean equalDistribution) { | ||
return dataset.get((int) (Math.random() * dataset.size())); | ||
} | ||
|
||
@Override | ||
public double[][] getTrainFeatures() { | ||
return featuresOf(trainingData); | ||
} | ||
|
||
private double[][] featuresOf(List<Penguin> data) { | ||
double[][] features = new double[data.size()][4]; // 4 numerische Eigenschaften | ||
|
||
for (int i = 0; i < data.size(); i++) { | ||
Penguin p = data.get(i); | ||
features[i][0] = p.billLengthMm(); | ||
features[i][1] = p.billDepthMm(); | ||
features[i][2] = p.flipperLengthMm(); | ||
features[i][3] = p.bodyMassG(); | ||
} | ||
|
||
return features; | ||
} | ||
|
||
|
||
@Override | ||
public double[][] getTrainLabels() { | ||
return labelsOf(trainingData); | ||
} | ||
|
||
private double[][] labelsOf(List<Penguin> data) { | ||
double[][] labels = new double[data.size()][3]; // 3 Pinguinarten | ||
|
||
for (int i = 0; i < data.size(); i++) { | ||
Penguin p = data.get(i); | ||
switch (p.species().toLowerCase()) { | ||
case "adelie": | ||
labels[i] = new double[]{1.0, 0.0, 0.0}; | ||
break; | ||
case "chinstrap": | ||
labels[i] = new double[]{0.0, 1.0, 0.0}; | ||
break; | ||
case "gentoo": | ||
labels[i] = new double[]{0.0, 0.0, 1.0}; | ||
break; | ||
default: | ||
throw new IllegalArgumentException("Unbekannte Pinguinart: " + p.species()); | ||
} | ||
} | ||
|
||
return labels; | ||
} | ||
|
||
@Override | ||
public double[][] getTestFeatures() { | ||
return featuresOf(testData); | ||
} | ||
|
||
@Override | ||
public double[][] getTestLabels() { | ||
return labelsOf(testData); | ||
} | ||
|
||
@Override | ||
public String getDescription() { | ||
return "The Seaborn Penguin dataset comprises measurements and species classifications for penguins collected from three islands in the Palmer Archipelago, Antarctica."; | ||
} | ||
} |
63 changes: 63 additions & 0 deletions
63
lib/src/test/java/de/edux/ml/nn/network/MultilayerPerceptronTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
package de.edux.ml.nn.network; | ||
|
||
import de.edux.data.provider.Penguin; | ||
import de.edux.data.provider.SeabornDataProcessor; | ||
import de.edux.data.provider.SeabornProvider; | ||
import de.edux.functions.activation.ActivationFunction; | ||
import de.edux.functions.initialization.Initialization; | ||
import de.edux.functions.loss.LossFunction; | ||
import de.edux.ml.nn.config.NetworkConfiguration; | ||
import org.junit.jupiter.api.BeforeEach; | ||
import org.junit.jupiter.api.RepeatedTest; | ||
import org.junit.jupiter.api.Test; | ||
|
||
import java.io.File; | ||
import java.net.URL; | ||
import java.util.List; | ||
|
||
import static org.junit.jupiter.api.Assertions.assertTrue; | ||
|
||
class MultilayerPerceptronTest { | ||
private static final boolean SHUFFLE = true; | ||
private static final boolean NORMALIZE = true; | ||
private static final boolean FILTER_INCOMPLETE_RECORDS = true; | ||
private static final double TRAIN_TEST_SPLIT_RATIO = 0.7; | ||
private static final String CSV_FILE_PATH = "testdatasets/seaborn-penguins/penguins.csv"; | ||
private SeabornProvider seabornProvider; | ||
|
||
@BeforeEach | ||
void setUp() { | ||
URL url = MultilayerPerceptronTest.class.getClassLoader().getResource(CSV_FILE_PATH); | ||
if (url == null) { | ||
throw new IllegalStateException("Cannot find file: " + CSV_FILE_PATH); | ||
} | ||
File csvFile = new File(url.getPath()); | ||
var seabornDataProcessor = new SeabornDataProcessor(); | ||
var dataset = seabornDataProcessor.loadTDataSet(csvFile, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS); | ||
List<List<Penguin>> trainTestSplittedList = seabornDataProcessor.split(dataset, TRAIN_TEST_SPLIT_RATIO); | ||
seabornProvider = new SeabornProvider(dataset, trainTestSplittedList.get(0), trainTestSplittedList.get(1)); | ||
|
||
} | ||
|
||
@RepeatedTest(3) | ||
void shouldReachModelAccuracyAtLeast70() { | ||
double[][] features = seabornProvider.getTrainFeatures(); | ||
double[][] labels = seabornProvider.getTrainLabels(); | ||
|
||
double[][] testFeatures = seabornProvider.getTestFeatures(); | ||
double[][] testLabels = seabornProvider.getTestLabels(); | ||
|
||
assertTrue(features.length > 0); | ||
assertTrue(labels.length > 0); | ||
assertTrue(testFeatures.length > 0); | ||
assertTrue(testLabels.length > 0); | ||
|
||
NetworkConfiguration networkConfiguration = new NetworkConfiguration(features[0].length, List.of(24, 6), 3, 0.001, 10000, ActivationFunction.LEAKY_RELU, ActivationFunction.SOFTMAX, LossFunction.CATEGORICAL_CROSS_ENTROPY, Initialization.XAVIER, Initialization.XAVIER); | ||
|
||
MultilayerPerceptron multilayerPerceptron = new MultilayerPerceptron(features, labels, testFeatures, testLabels, networkConfiguration); | ||
multilayerPerceptron.train(); | ||
double accuracy = multilayerPerceptron.getAccuracy(); | ||
assertTrue(accuracy > 0.7); | ||
|
||
} | ||
} |
Oops, something went wrong.