Skip to content

Commit

Permalink
Write jUnit Test for decision trees #28
Browse files Browse the repository at this point in the history
  • Loading branch information
Samuel Abramov committed Oct 6, 2023
1 parent 85927ca commit 787786c
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,18 @@ public static void main(String[] args) {
double[][] features = datasetProvider.getTrainFeatures();
double[][] labels = datasetProvider.getTrainLabels();

// 1 - SATOSA 2 - VERSICOLOR 3 - VIRGINICA
int[] decisionTreeTrainLabels = convert2DLabelArrayTo1DLabelArray(labels);

// Train Decision Tree
IDecisionTree decisionTree = new DecisionTree();
decisionTree.train(features, decisionTreeTrainLabels, 6, 2, 1, 4);
decisionTree.train(features, labels, 6, 2, 1, 4);

// Evaluate Decision Tree
double[][] testFeatures = datasetProvider.getTestFeatures();
double[][] testLabels = datasetProvider.getTestLabels();
int[] decisionTreeTestLabels = convert2DLabelArrayTo1DLabelArray(testLabels);
decisionTree.evaluate(testFeatures, decisionTreeTestLabels);
decisionTree.evaluate(testFeatures, testLabels);

// Get Feature Importance
double[] featureImportance = decisionTree.getFeatureImportance();
System.out.println("Feature Importance: " + Arrays.toString(featureImportance));
}


}
69 changes: 39 additions & 30 deletions lib/src/main/java/de/edux/ml/decisiontree/DecisionTree.java
Original file line number Diff line number Diff line change
@@ -1,36 +1,21 @@
package de.edux.ml.decisiontree;

import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
* A decision tree classifier.
* <p>
* This class implements a binary decision tree algorithm for classification.
* The decision tree is built by recursively splitting the training data based on
* the feature that results in the minimum Gini index, which is a measure of impurity.
* </p>
*
* <p>
* Once the decision tree is built, new instances can be classified by traversing the tree
* from the root to a leaf node. The class of the leaf node is then assigned to the instance.
* </p>
*
* <p>
* The decision tree algorithm implemented here includes several stopping conditions to avoid
* overfitting, including a maximum depth, a minimum number of samples per leaf, and a minimum
* number of samples to allow a split.
* </p>
*
* <p>
* The decision tree can be used for multiclass classification problems. For binary classification,
* the output is either 0 or 1. For multiclass classification, the output is the class with the
* maximum frequency in the leaf node.
* </p>
*/
*/
public class DecisionTree implements IDecisionTree {
private static final Logger LOG = LoggerFactory.getLogger(DecisionTree.class);
private Node root;
Expand All @@ -39,8 +24,6 @@ public class DecisionTree implements IDecisionTree {
private int minSamplesLeaf;
private int maxLeafNodes;



private double calculateGiniIndex(double[] labels) {
if (labels.length == 0) {
return 0.0;
Expand Down Expand Up @@ -120,7 +103,7 @@ private void buildTree(Node node) {
@Override
public void train(
double[][] features,
int[] labels,
double[][] labels,
int maxDepth,
int minSamplesSplit,
int minSamplesLeaf,
Expand All @@ -133,14 +116,31 @@ public void train(
double[][] data = new double[features.length][];
for (int i = 0; i < features.length; i++) {
data[i] = Arrays.copyOf(features[i], features[i].length + 1);
data[i][data[i].length - 1] = labels[i];
data[i][data[i].length - 1] = getIndexOfHighestValue(labels[i]);
}
root = new Node(data);
buildTree(root);
}

private double getIndexOfHighestValue(double[] labels) {
if (labels == null || labels.length == 0) {
throw new IllegalArgumentException("Array must not be null or empty");
}

int maxIndex = 0;
double maxValue = labels[0];

for (int i = 1; i < labels.length; i++) {
if (labels[i] > maxValue) {
maxValue = labels[i];
maxIndex = i;
}
}

return maxIndex;
}

@Override
// Add to the DecisionTree class
public double predict(double[] feature) {
return predict(feature, root);
}
Expand Down Expand Up @@ -172,18 +172,27 @@ private double getMostCommonLabel(double[][] data) {
}

@Override
public double evaluate(double[][] features, int[] labels) {
public double evaluate(double[][] features, double[][] labels) {
int correctPredictions = 0;
for (int i = 0; i < features.length; i++) {
if (predict(features[i]) == labels[i]) {
double predictedLabel = predict(features[i]);
double actualLabel = getIndexOfHighestValue(labels[i]);

if (predictedLabel == actualLabel) {
correctPredictions++;
}
}

// Calculate accuracy: ratio of correct predictions to total predictions
double accuracy = (double) correctPredictions / features.length;
LOG.info("Accuracy: " + String.format("%.4f", accuracy * 100) + "%");

// Log the accuracy value (optional)
LOG.info("Model Accuracy: {}%", accuracy * 100);

return accuracy;
}


@Override
public double[] getFeatureImportance() {
int numFeatures = root.data[0].length - 1;
Expand Down Expand Up @@ -236,4 +245,4 @@ private int getLeafCount(Node node) {
return getLeafCount(node.left) + getLeafCount(node.right);
}
}
}
}
4 changes: 2 additions & 2 deletions lib/src/main/java/de/edux/ml/decisiontree/IDecisionTree.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ public interface IDecisionTree {

void train(
double[][] features,
int[] labels,
double[][] labels,
int maxDepth,
int minSamplesSplit,
int minSamplesLeaf,
Expand All @@ -24,7 +24,7 @@ void train(
* @param labels the labels to evaluate
* @return true if the decision tree correctly classified the features and labels, false otherwise
*/
double evaluate(double[][] features, int[] labels);
double evaluate(double[][] features, double[][] labels);

/**
* Returns the feature importance of the decision tree.
Expand Down
66 changes: 66 additions & 0 deletions lib/src/test/java/de/edux/ml/decisiontree/DecisionTreeTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package de.edux.ml.decisiontree;

import de.edux.data.provider.Penguin;
import de.edux.data.provider.SeabornDataProcessor;
import de.edux.data.provider.SeabornProvider;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Test;

import java.io.File;
import java.net.URL;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertTrue;

class DecisionTreeTest {
private static final boolean SHUFFLE = true;
private static final boolean NORMALIZE = true;
private static final boolean FILTER_INCOMPLETE_RECORDS = true;
private static final double TRAIN_TEST_SPLIT_RATIO = 0.7;
private static final String CSV_FILE_PATH = "testdatasets/seaborn-penguins/penguins.csv";
private static SeabornProvider seabornProvider;
@BeforeAll
static void setup() {
URL url = DecisionTreeTest.class.getClassLoader().getResource(CSV_FILE_PATH);
if (url == null) {
throw new IllegalStateException("Cannot find file: " + CSV_FILE_PATH);
}
File csvFile = new File(url.getPath());
var seabornDataProcessor = new SeabornDataProcessor();
var dataset = seabornDataProcessor.loadTDataSet(csvFile, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS);
List<List<Penguin>> trainTestSplittedList = seabornDataProcessor.split(dataset, TRAIN_TEST_SPLIT_RATIO);
seabornProvider = new SeabornProvider(dataset, trainTestSplittedList.get(0), trainTestSplittedList.get(1));
}

@RepeatedTest(5)
void train() {
double[][] features = seabornProvider.getTrainFeatures();
double[][] labels = seabornProvider.getTrainLabels();

double[][] testFeatures = seabornProvider.getTestFeatures();
double[][] testLabels = seabornProvider.getTestLabels();

assertTrue(features.length > 0);
assertTrue(labels.length > 0);
assertTrue(testFeatures.length > 0);
assertTrue(testLabels.length > 0);

IDecisionTree decisionTree = new DecisionTree();
decisionTree.train(features, labels, 10, 2, 1, 8);
double accuracy = decisionTree.evaluate(testFeatures, testLabels);
assertTrue(accuracy>0.7);
}

@Test
void predict() {
}

@Test
void evaluate() {
}

@Test
void getFeatureImportance() {
}
}

0 comments on commit 787786c

Please sign in to comment.