diff --git a/example/src/main/java/de/example/decisiontree/DecisionTreeExample.java b/example/src/main/java/de/example/decisiontree/DecisionTreeExample.java
index f070e80..9b03f7a 100644
--- a/example/src/main/java/de/example/decisiontree/DecisionTreeExample.java
+++ b/example/src/main/java/de/example/decisiontree/DecisionTreeExample.java
@@ -20,23 +20,18 @@ public static void main(String[] args) {
     double[][] features = datasetProvider.getTrainFeatures();
     double[][] labels = datasetProvider.getTrainLabels();

-    // 1 - SATOSA 2 - VERSICOLOR 3 - VIRGINICA
-    int[] decisionTreeTrainLabels = convert2DLabelArrayTo1DLabelArray(labels);
-
     // Train Decision Tree
     IDecisionTree decisionTree = new DecisionTree();
-    decisionTree.train(features, decisionTreeTrainLabels, 6, 2, 1, 4);
+    decisionTree.train(features, labels, 6, 2, 1, 4);

     // Evaluate Decision Tree
     double[][] testFeatures = datasetProvider.getTestFeatures();
     double[][] testLabels = datasetProvider.getTestLabels();
-    int[] decisionTreeTestLabels = convert2DLabelArrayTo1DLabelArray(testLabels);
-    decisionTree.evaluate(testFeatures, decisionTreeTestLabels);
+    decisionTree.evaluate(testFeatures, testLabels);

     // Get Feature Importance
     double[] featureImportance = decisionTree.getFeatureImportance();
     System.out.println("Feature Importance: " + Arrays.toString(featureImportance));
   }
-
 }
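Note on the example change: train() and evaluate() now take the one-hot encoded label matrix from the dataset provider directly, so the convert2DLabelArrayTo1DLabelArray step disappears. A minimal caller after this patch, sketched with the same provider and hyperparameters as the example above (the last four arguments are maxDepth, minSamplesSplit, minSamplesLeaf and, presumably, the leaf-node limit; datasetProvider is assumed to be the Iris provider already set up in DecisionTreeExample), would look roughly like this:

    // Sketch only, not a drop-in replacement for DecisionTreeExample.
    double[][] features = datasetProvider.getTrainFeatures();
    double[][] labels = datasetProvider.getTrainLabels();      // one-hot rows, e.g. {0, 1, 0}

    IDecisionTree decisionTree = new DecisionTree();
    decisionTree.train(features, labels, 6, 2, 1, 4);

    double accuracy = decisionTree.evaluate(datasetProvider.getTestFeatures(),
                                            datasetProvider.getTestLabels());
    System.out.println("Accuracy: " + accuracy);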

diff --git a/lib/src/main/java/de/edux/ml/decisiontree/DecisionTree.java b/lib/src/main/java/de/edux/ml/decisiontree/DecisionTree.java
index 4f56ec2..0a08291 100644
--- a/lib/src/main/java/de/edux/ml/decisiontree/DecisionTree.java
+++ b/lib/src/main/java/de/edux/ml/decisiontree/DecisionTree.java
@@ -1,11 +1,13 @@
 package de.edux.ml.decisiontree;

-import java.util.*;
-import java.util.function.Function;
-import java.util.stream.Collectors;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

+import java.util.Arrays;
+import java.util.Map;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
 /**
  * A decision tree classifier.
  *
@@ -13,24 +15,7 @@
  * The decision tree is built by recursively splitting the training data based on
  * the feature that results in the minimum Gini index, which is a measure of impurity.
  *
- *
- * Once the decision tree is built, new instances can be classified by traversing the tree
- * from the root to a leaf node. The class of the leaf node is then assigned to the instance.
- *
- *
- * The decision tree algorithm implemented here includes several stopping conditions to avoid
- * overfitting, including a maximum depth, a minimum number of samples per leaf, and a minimum
- * number of samples to allow a split.
- *
- *
- * The decision tree can be used for multiclass classification problems. For binary classification,
- * the output is either 0 or 1. For multiclass classification, the output is the class with the
- * maximum frequency in the leaf node.
- *
- */
+*/
 public class DecisionTree implements IDecisionTree {
   private static final Logger LOG = LoggerFactory.getLogger(DecisionTree.class);
   private Node root;
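The sentence kept in the class Javadoc is the load-bearing one: splits are chosen by minimum Gini index, i.e. one minus the sum of squared class probabilities. A small standalone sketch of that computation, using the java.util imports added in the hunk above (illustrative only, not a copy of the private calculateGiniIndex in this class):

    // Gini impurity of a 1-D label array: 0.0 for a pure node, larger for mixed nodes.
    static double gini(double[] labels) {
        if (labels.length == 0) {
            return 0.0;
        }
        Map<Double, Long> counts = Arrays.stream(labels)
                .boxed()
                .collect(Collectors.groupingBy(label -> label, Collectors.counting()));
        double impurity = 1.0;
        for (long count : counts.values()) {
            double p = (double) count / labels.length;
            impurity -= p * p;   // subtract the squared probability of each class
        }
        return impurity;         // e.g. gini(new double[]{0, 0, 1, 1}) == 0.5
    }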
@@ -39,8 +24,6 @@ public class DecisionTree implements IDecisionTree {
   private int minSamplesLeaf;
   private int maxLeafNodes;
-
-
   private double calculateGiniIndex(double[] labels) {
     if (labels.length == 0) {
       return 0.0;
     }
@@ -120,7 +103,7 @@ private void buildTree(Node node) {
   @Override
   public void train(
       double[][] features,
-      int[] labels,
+      double[][] labels,
       int maxDepth,
       int minSamplesSplit,
       int minSamplesLeaf,
@@ -133,14 +116,31 @@ public void train(
     double[][] data = new double[features.length][];
     for (int i = 0; i < features.length; i++) {
       data[i] = Arrays.copyOf(features[i], features[i].length + 1);
-      data[i][data[i].length - 1] = labels[i];
+      data[i][data[i].length - 1] = getIndexOfHighestValue(labels[i]);
     }
     root = new Node(data);
     buildTree(root);
   }

+  private double getIndexOfHighestValue(double[] labels) {
+    if (labels == null || labels.length == 0) {
+      throw new IllegalArgumentException("Array must not be null or empty");
+    }
+
+    int maxIndex = 0;
+    double maxValue = labels[0];
+
+    for (int i = 1; i < labels.length; i++) {
+      if (labels[i] > maxValue) {
+        maxValue = labels[i];
+        maxIndex = i;
+      }
+    }
+
+    return maxIndex;
+  }
+
   @Override
-  // Add to the DecisionTree class
   public double predict(double[] feature) {
     return predict(feature, root);
   }
@@ -172,18 +172,27 @@ private double getMostCommonLabel(double[][] data) {
   }

   @Override
-  public double evaluate(double[][] features, int[] labels) {
+  public double evaluate(double[][] features, double[][] labels) {
     int correctPredictions = 0;
     for (int i = 0; i < features.length; i++) {
-      if (predict(features[i]) == labels[i]) {
+      double predictedLabel = predict(features[i]);
+      double actualLabel = getIndexOfHighestValue(labels[i]);
+
+      if (predictedLabel == actualLabel) {
         correctPredictions++;
       }
     }
+
+    // Calculate accuracy: ratio of correct predictions to total predictions
     double accuracy = (double) correctPredictions / features.length;
-    LOG.info("Accuracy: " + String.format("%.4f", accuracy * 100) + "%");
+
+    // Log the accuracy value (optional)
+    LOG.info("Model Accuracy: {}%", accuracy * 100);
+
     return accuracy;
   }

+
   @Override
   public double[] getFeatureImportance() {
     int numFeatures = root.data[0].length - 1;
@@ -236,4 +245,4 @@ private int getLeafCount(Node node) {
     return getLeafCount(node.left) + getLeafCount(node.right);
   }
 }
-}
+}
\ No newline at end of file
diff --git a/lib/src/main/java/de/edux/ml/decisiontree/IDecisionTree.java b/lib/src/main/java/de/edux/ml/decisiontree/IDecisionTree.java
index 5226798..69b4cc7 100644
--- a/lib/src/main/java/de/edux/ml/decisiontree/IDecisionTree.java
+++ b/lib/src/main/java/de/edux/ml/decisiontree/IDecisionTree.java
@@ -4,7 +4,7 @@ public interface IDecisionTree {
   void train(
       double[][] features,
-      int[] labels,
+      double[][] labels,
       int maxDepth,
       int minSamplesSplit,
       int minSamplesLeaf,
@@ -24,7 +24,7 @@ void train(
    * @param labels the labels to evaluate
    * @return true if the decision tree correctly classified the features and labels, false otherwise
    */
-  double evaluate(double[][] features, int[] labels);
+  double evaluate(double[][] features, double[][] labels);

   /**
    * Returns the feature importance of the decision tree.
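Both interface methods now accept the labels as a one-hot matrix (double[][]) instead of int[]; DecisionTree collapses each row to a class index with getIndexOfHighestValue, which is just an argmax. In isolation the conversion amounts to:

    // Each one-hot row becomes the index of its largest entry before it is stored
    // in the last column of the internal data array and compared against predict().
    double[] oneHot = {0.0, 1.0, 0.0};   // class 1 of 3

    int maxIndex = 0;
    for (int i = 1; i < oneHot.length; i++) {
        if (oneHot[i] > oneHot[maxIndex]) {
            maxIndex = i;
        }
    }
    // maxIndex == 1

One side note on the unchanged Javadoc context in the hunk just above: evaluate() returns the accuracy as a double, so the "@return true ... false otherwise" wording is left over from an earlier boolean signature.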
diff --git a/lib/src/test/java/de/edux/ml/decisiontree/DecisionTreeTest.java b/lib/src/test/java/de/edux/ml/decisiontree/DecisionTreeTest.java
new file mode 100644
index 0000000..25b1d4e
--- /dev/null
+++ b/lib/src/test/java/de/edux/ml/decisiontree/DecisionTreeTest.java
@@ -0,0 +1,66 @@
+package de.edux.ml.decisiontree;
+
+import de.edux.data.provider.Penguin;
+import de.edux.data.provider.SeabornDataProcessor;
+import de.edux.data.provider.SeabornProvider;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.RepeatedTest;
+import org.junit.jupiter.api.Test;
+
+import java.io.File;
+import java.net.URL;
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+class DecisionTreeTest {
+    private static final boolean SHUFFLE = true;
+    private static final boolean NORMALIZE = true;
+    private static final boolean FILTER_INCOMPLETE_RECORDS = true;
+    private static final double TRAIN_TEST_SPLIT_RATIO = 0.7;
+    private static final String CSV_FILE_PATH = "testdatasets/seaborn-penguins/penguins.csv";
+    private static SeabornProvider seabornProvider;
+    @BeforeAll
+    static void setup() {
+        URL url = DecisionTreeTest.class.getClassLoader().getResource(CSV_FILE_PATH);
+        if (url == null) {
+            throw new IllegalStateException("Cannot find file: " + CSV_FILE_PATH);
+        }
+        File csvFile = new File(url.getPath());
+        var seabornDataProcessor = new SeabornDataProcessor();
+        var dataset = seabornDataProcessor.loadTDataSet(csvFile, ',', SHUFFLE, NORMALIZE, FILTER_INCOMPLETE_RECORDS);
+        List<List<Penguin>> trainTestSplittedList = seabornDataProcessor.split(dataset, TRAIN_TEST_SPLIT_RATIO);
+        seabornProvider = new SeabornProvider(dataset, trainTestSplittedList.get(0), trainTestSplittedList.get(1));
+    }
+
+    @RepeatedTest(5)
+    void train() {
+        double[][] features = seabornProvider.getTrainFeatures();
+        double[][] labels = seabornProvider.getTrainLabels();
+
+        double[][] testFeatures = seabornProvider.getTestFeatures();
+        double[][] testLabels = seabornProvider.getTestLabels();
+
+        assertTrue(features.length > 0);
+        assertTrue(labels.length > 0);
+        assertTrue(testFeatures.length > 0);
+        assertTrue(testLabels.length > 0);
+
+        IDecisionTree decisionTree = new DecisionTree();
+        decisionTree.train(features, labels, 10, 2, 1, 8);
+        double accuracy = decisionTree.evaluate(testFeatures, testLabels);
+        assertTrue(accuracy > 0.7);
+    }
+
+    @Test
+    void predict() {
+    }
+
+    @Test
+    void evaluate() {
+    }
+
+    @Test
+    void getFeatureImportance() {
+    }
+}
\ No newline at end of file
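predict(), evaluate() and getFeatureImportance() are still empty stubs in this test class. One plausible way to fill the predict() stub, reusing the provider wired up in setup() (a sketch under that assumption, not part of this patch):

    @Test
    void predict() {
        double[][] features = seabornProvider.getTrainFeatures();
        double[][] labels = seabornProvider.getTrainLabels();

        IDecisionTree decisionTree = new DecisionTree();
        decisionTree.train(features, labels, 10, 2, 1, 8);

        // Every prediction should be one of the known class indices (0 .. numClasses - 1).
        int numClasses = labels[0].length;
        for (double[] testFeature : seabornProvider.getTestFeatures()) {
            double predicted = decisionTree.predict(testFeature);
            assertTrue(predicted >= 0 && predicted < numClasses);
        }
    }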