Skip to content

Commit

Permalink
Merge pull request #71 from Samyssmile/decisiontreefix
Browse files Browse the repository at this point in the history
fix(#70): correct leaf node counting in DecisionTree implementation
  • Loading branch information
Samyssmile authored Oct 28, 2023
2 parents fc5c7ba + 02250e1 commit 7ae2216
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 36 deletions.
12 changes: 3 additions & 9 deletions example/src/main/java/de/example/benchmark/Benchmark.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ private void run() {
initFeaturesAndLabels();

Classifier knn = new KnnClassifier(2);
Classifier decisionTree = new DecisionTree(8, 2, 1, 3);
Classifier randomForest = new RandomForest(100, 10, 2, 1, 3, 60);
Classifier decisionTree = new DecisionTree(2, 2, 3, 12);
Classifier randomForest = new RandomForest(500, 10, 2, 3, 3, 60);
Classifier svm = new SupportVectorMachine(SVMKernel.LINEAR, 1);

networkConfiguration = new NetworkConfiguration(trainFeatures[0].length, List.of(128, 256, 512), 3, 0.01, 300, ActivationFunction.LEAKY_RELU, ActivationFunction.SOFTMAX, LossFunction.CATEGORICAL_CROSS_ENTROPY, Initialization.XAVIER, Initialization.XAVIER);
Expand All @@ -66,7 +66,7 @@ private void run() {
results.put("MLP", new ArrayList<>());


IntStream.range(0, 1).forEach(i -> {
IntStream.range(0, 5).forEach(i -> {
knn.train(trainFeatures, trainLabels);
decisionTree.train(trainFeatures, trainLabels);
randomForest.train(trainFeatures, trainLabels);
Expand All @@ -88,7 +88,6 @@ private void run() {
updateMLP(testFeatures, testLabels);
});


System.out.println("Classifier performances (sorted by average accuracy):");
results.entrySet().stream()
.map(entry -> {
Expand All @@ -103,7 +102,6 @@ private void run() {
System.out.printf("%s: %.2f%%\n", entry.getKey(), entry.getValue() * 100);
});

// Additionally, if you want to show other metrics, such as minimum or maximum accuracy, you can calculate and display them similarly.
System.out.println("\nClassifier best and worst performances:");
results.forEach((classifierName, accuracies) -> {
double maxAccuracy = accuracies.stream()
Expand All @@ -116,8 +114,6 @@ private void run() {
.orElse(0.0);
System.out.printf("%s: Best: %.2f%%, Worst: %.2f%%\n", classifierName, maxAccuracy * 100, minAccuracy * 100);
});


}

private void updateMLP(double[][] testFeatures, double[][] testLabels) {
Expand All @@ -134,8 +130,6 @@ private void initFeaturesAndLabels() {
.shuffle()
.split(TRAIN_TEST_SPLIT_RATIO);



trainFeatures = dataProcessor.getTrainFeatures(featureColumnIndices);
trainLabels = dataProcessor.getTrainLabels(targetColumnIndex);
testFeatures = dataProcessor.getTestFeatures(featureColumnIndices);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package de.example.decisiontree;

import de.edux.api.Classifier;
import de.edux.data.provider.DataProcessor;
import de.edux.data.reader.CSVIDataReader;
import de.edux.ml.decisiontree.DecisionTree;
import de.edux.ml.randomforest.RandomForest;

import java.io.File;

public class DecisionTreeExampleOnIrisDataset {
    private static final double TRAIN_TEST_SPLIT_RATIO = 0.70;
    private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "iris" + File.separator + "iris.csv");
    private static final boolean SKIP_HEAD = true;

    /**
     * Trains a {@link DecisionTree} on the IRIS dataset and evaluates it on a held-out split.
     *
     * <p>Dataset layout (5 columns):
     * sepal.length | sepal.width | petal.length | petal.width | variety
     * e.g. 5.1 | 3.5 | 1.4 | .2 | Setosa
     */
    public static void main(String[] args) {
        int[] featureColumnIndices = {0, 1, 2, 3}; // first four columns are numeric features
        int targetColumnIndex = 4;                 // last column holds the class label

        // Load, normalize, shuffle and split the data in one pipeline.
        var irisDataProcessor = new DataProcessor(new CSVIDataReader())
                .loadDataSetFromCSV(CSV_FILE, ',', SKIP_HEAD, featureColumnIndices, targetColumnIndex)
                .normalize()
                .shuffle()
                .split(TRAIN_TEST_SPLIT_RATIO);

        Classifier classifier = new DecisionTree(2, 2, 3, 12);

        var trainFeatures = irisDataProcessor.getTrainFeatures(featureColumnIndices);
        var trainTestFeatures = irisDataProcessor.getTestFeatures(featureColumnIndices);
        var trainLabels = irisDataProcessor.getTrainLabels(targetColumnIndex);
        var trainTestLabels = irisDataProcessor.getTestLabels(targetColumnIndex);

        classifier.train(trainFeatures, trainLabels);
        classifier.evaluate(trainTestFeatures, trainTestLabels);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import java.io.File;
import java.util.List;

public class MultilayerPerceptronExampleV2 {
public class MultilayerNeuralNetworkExampleOnIrisDataset {

private static final double TRAIN_TEST_SPLIT_RATIO = 0.70;
private static final File CSV_FILE = new File("example" + File.separator + "datasets" + File.separator + "iris" + File.separator + "iris.csv");
Expand Down
86 changes: 60 additions & 26 deletions lib/src/main/java/de/edux/ml/decisiontree/DecisionTree.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ public class DecisionTree implements Classifier {
private final int minSamplesSplit;
private final int minSamplesLeaf;
private final int maxLeafNodes;
private int currentLeafNodes;

private final Map<Integer, Double> featureImportances;

public DecisionTree(int maxDepth,
Expand All @@ -61,6 +63,7 @@ public DecisionTree(int maxDepth,
this.minSamplesSplit = minSamplesSplit;
this.minSamplesLeaf = minSamplesLeaf;
this.maxLeafNodes = maxLeafNodes;
this.currentLeafNodes = 0;
this.featureImportances = new HashMap<>();
}

Expand All @@ -85,18 +88,35 @@ private Node buildTree(double[][] features, double[][] labels, int depth) {
Node node = new Node(features);
node.predictedLabel = getMajorityLabel(labels);

if (shouldTerminate(features, depth)) {
currentLeafNodes++;
return node;
}

SplitResult bestSplit = findBestSplit(features, labels);
if (bestSplit != null) {
applyBestSplit(node, bestSplit, features, labels, depth);
} else {
currentLeafNodes++;
}

return node;
}

/**
 * Decides whether tree growth must stop at the current node.
 *
 * <p>Growth stops when the maximum depth is reached, when there are too few
 * samples left to split, or when the tree already holds the maximum allowed
 * number of leaf nodes.
 *
 * @param features samples reaching this node
 * @param depth    current depth of the node being built
 * @return {@code true} if this node must become a leaf
 */
private boolean shouldTerminate(double[][] features, int depth) {
    boolean maxDepthReached = depth >= maxDepth;
    boolean tooFewSamples = features.length < minSamplesSplit;
    // Must read the running leaf-counter FIELD. Re-declaring a local
    // `int currentLeafNodes = 0;` here (the pre-fix residue in this span)
    // shadows the field and silently disables the maxLeafNodes limit.
    boolean maxLeafNodesReached = currentLeafNodes >= maxLeafNodes;
    return maxDepthReached || tooFewSamples || maxLeafNodesReached;
}

private SplitResult findBestSplit(double[][] features, double[][] labels) {
double bestGini = Double.MAX_VALUE;
double[][] bestLeftFeatures = null, bestRightFeatures = null;
double[][] bestLeftLabels = null, bestRightLabels = null;
SplitResult bestSplit = null;

for (int featureIndex = 0; featureIndex < features[0].length; featureIndex++) {
for (double[] feature : features) {
Expand All @@ -111,29 +131,31 @@ private Node buildTree(double[][] features, double[][] labels, int depth) {
if (gini < bestGini) {
bestGini = gini;
updateFeatureImportances(featureIndex, gini);
bestLeftFeatures = leftFeatures;
bestRightFeatures = rightFeatures;
bestLeftLabels = leftLabels;
bestRightLabels = rightLabels;
node.splitFeatureIndex = featureIndex;
node.splitValue = feature[featureIndex];
bestSplit = new SplitResult(featureIndex, feature[featureIndex], leftFeatures, rightFeatures, leftLabels, rightLabels);
}
}
}

if (bestLeftFeatures != null && bestRightFeatures != null &&
bestLeftFeatures.length >= minSamplesLeaf && bestRightFeatures.length >= minSamplesLeaf) {
return bestSplit;
}

Node leftChild = buildTree(bestLeftFeatures, bestLeftLabels, depth + 1);
Node rightChild = buildTree(bestRightFeatures, bestRightLabels, depth + 1);
private void applyBestSplit(Node node, SplitResult bestSplit, double[][] features, double[][] labels, int depth) {
node.splitFeatureIndex = bestSplit.featureIndex;
node.splitValue = bestSplit.splitValue;

if(currentLeafNodes + 2 <= maxLeafNodes) {
node.left = leftChild;
node.right = rightChild;
if (bestSplit.bestLeftFeatures != null && bestSplit.bestRightFeatures != null &&
bestSplit.bestLeftFeatures.length >= minSamplesLeaf && bestSplit.bestRightFeatures.length >= minSamplesLeaf) {

if (currentLeafNodes + 2 <= maxLeafNodes) {
node.left = buildTree(bestSplit.bestLeftFeatures, bestSplit.bestLeftLabels, depth + 1);
node.right = buildTree(bestSplit.bestRightFeatures, bestSplit.bestRightLabels, depth + 1);
currentLeafNodes += 2;
} else {
currentLeafNodes++;
}
} else {
currentLeafNodes++;
}

return node;
}

private void updateFeatureImportances(int featureIndex, double giniReduction) {
Expand All @@ -148,7 +170,6 @@ public Map<Integer, Double> getFeatureImportances() {
e -> e.getValue() / totalImportance));
}


private double[][] filterRows(double[][] matrix, int featureIndex, double value, boolean lessThan) {
return Arrays.stream(matrix)
.filter(row -> (lessThan && row[featureIndex] < value) || (!lessThan && row[featureIndex] >= value))
Expand Down Expand Up @@ -232,8 +253,6 @@ private double[] predictRecursive(Node node, double[] feature) {
}
}



private double computeGini(double[][] leftLabels, double[][] rightLabels) {
double leftImpurity = computeImpurity(leftLabels);
double rightImpurity = computeImpurity(rightLabels);
Expand All @@ -242,7 +261,6 @@ private double computeGini(double[][] leftLabels, double[][] rightLabels) {
return leftWeight * leftImpurity + rightWeight * rightImpurity;
}


private double computeImpurity(double[][] labels) {
double impurity = 1.0;
Map<String, Long> labelCounts = Arrays.stream(labels)
Expand All @@ -255,9 +273,25 @@ private double computeImpurity(double[][] labels) {
return impurity;
}



static class Node {
/**
 * Immutable carrier for the best split found for a node: the feature column
 * tested, the threshold value, and the resulting left/right partitions of
 * features and labels. Fields are only assigned in the constructor and read
 * by the tree-building code, so they are declared {@code final}.
 */
private static class SplitResult {
    final int featureIndex;   // index of the feature column the split tests
    final double splitValue;  // threshold taken from the splitting sample's feature value
    final double[][] bestLeftFeatures;
    final double[][] bestRightFeatures;
    final double[][] bestLeftLabels;
    final double[][] bestRightLabels;

    SplitResult(int featureIndex, double splitValue, double[][] bestLeftFeatures, double[][] bestRightFeatures,
                double[][] bestLeftLabels, double[][] bestRightLabels) {
        this.featureIndex = featureIndex;
        this.splitValue = splitValue;
        this.bestLeftFeatures = bestLeftFeatures;
        this.bestRightFeatures = bestRightFeatures;
        this.bestLeftLabels = bestLeftLabels;
        this.bestRightLabels = bestRightLabels;
    }
}
private static class Node {
double[][] data;
Node left;
Node right;
Expand Down

0 comments on commit 7ae2216

Please sign in to comment.