Commit
Merge pull request #39 from Samyssmile/chore/docs
chore(): docs
Samyssmile authored Oct 10, 2023
2 parents f0c0ae8 + 516641f commit f7d32aa
Showing 5 changed files with 88 additions and 10 deletions.
5 changes: 2 additions & 3 deletions example/src/main/java/de/example/benchmark/Benchmark.java
@@ -14,12 +14,11 @@
import de.example.data.seaborn.Penguin;
import de.example.data.seaborn.SeabornDataProcessor;

-import java.util.ArrayList;
-import java.util.Map;
 import java.io.File;
+import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;

/**
26 changes: 26 additions & 0 deletions lib/src/main/java/de/edux/ml/knn/KnnClassifier.java
@@ -7,13 +7,39 @@
import java.util.Arrays;
import java.util.PriorityQueue;

/**
 * The {@code KnnClassifier} class provides an implementation of the k-Nearest Neighbors algorithm for classification tasks.
 * It stores the training dataset and predicts the label of a new data point from the majority label of its k nearest neighbors in feature space.
 * Distances between data points are computed with the Euclidean metric.
 * Optionally, votes can be weighted by the inverse of the distance, giving closer neighbors greater influence.
*
* <p>Example usage:</p>
* <pre>{@code
* int k = 3; // Specify the number of neighbors to consider
* KnnClassifier knn = new KnnClassifier(k);
* knn.train(trainingFeatures, trainingLabels);
*
* double[] prediction = knn.predict(inputFeatures);
* double accuracy = knn.evaluate(testFeatures, testLabels);
* }</pre>
*
* <p>Note: The label arrays should be in one-hot encoding format.</p>
 */
public class KnnClassifier implements Classifier {
private static final Logger LOG = LoggerFactory.getLogger(KnnClassifier.class);
private double[][] trainFeatures;
private double[][] trainLabels;
private int k;
private static final double EPSILON = 1e-10;

/**
 * Initializes a new instance of {@code KnnClassifier} with the specified k.
*
* @param k an integer value representing the number of neighbors to consider during classification
* @throws IllegalArgumentException if k is not a positive integer
*/
public KnnClassifier(int k) {
if (k <= 0) {
throw new IllegalArgumentException("k must be a positive integer");
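To make the documented behavior concrete, here is a minimal sketch of distance-weighted kNN voting. It is not the edux implementation: the helper names (euclideanDistance, weightedVote, argMax) are hypothetical, and using EPSILON to avoid division by zero is an assumption based on the EPSILON field shown above.

// Hypothetical sketch of distance-weighted kNN voting; not the edux source.
// Assumes one-hot encoded labels, as the Javadoc requires.
import java.util.HashMap;
import java.util.Map;

public class KnnVotingSketch {
    private static final double EPSILON = 1e-10; // guards against division by zero

    // Euclidean distance between two feature vectors.
    static double euclideanDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    // Each neighbor votes for its class (argmax of its one-hot label),
    // weighted by the inverse of its distance to the query point.
    static int weightedVote(double[][] neighborLabels, double[] distances) {
        Map<Integer, Double> votes = new HashMap<>();
        for (int i = 0; i < neighborLabels.length; i++) {
            votes.merge(argMax(neighborLabels[i]), 1.0 / (distances[i] + EPSILON), Double::sum);
        }
        return votes.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .orElseThrow()
                .getKey();
    }

    // Index of the largest component, i.e. the class id of a one-hot label.
    static int argMax(double[] v) {
        int best = 0;
        for (int i = 1; i < v.length; i++) {
            if (v[i] > v[best]) best = i;
        }
        return best;
    }
}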
40 changes: 38 additions & 2 deletions lib/src/main/java/de/edux/ml/nn/network/MultilayerPerceptron.java
@@ -1,16 +1,52 @@
package de.edux.ml.nn.network;

import de.edux.api.Classifier;
import de.edux.functions.initialization.Initialization;
-import de.edux.ml.nn.Neuron;
 import de.edux.functions.activation.ActivationFunction;
+import de.edux.ml.nn.Neuron;
import de.edux.ml.nn.config.NetworkConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;

/**
* The {@code MultilayerPerceptron} class represents a simple feedforward neural network,
* which consists of input, hidden, and output layers. It implements the {@code Classifier}
* interface, facilitating both the training and prediction processes on a given dataset.
*
 * <p>This implementation uses the backpropagation algorithm to train the network, adjusting
 * weights and biases according to the configuration defined by {@link NetworkConfiguration}.
* The network's architecture is multi-layered, comprising one or more hidden layers in addition
* to the input and output layers. Neurons within these layers utilize activation functions defined
* per layer through the configuration.</p>
*
 * <p>Training adjusts the weights and biases of the neurons based on the error between
 * predicted and expected outputs. The best model observed during training (measured by
 * accuracy) is saved and restored at the end. Early stopping is applied by monitoring
 * accuracy improvement across epochs, avoiding overfitting and unnecessary computation.</p>
*
* <p>Usage example:</p>
 * <pre>{@code
 * NetworkConfiguration config = ... ;
 * double[][] features = ... ;
 * double[][] labels = ... ;
 * double[][] testFeatures = ... ;
 * double[][] testLabels = ... ;
 * double[] singleInput = ... ;
 *
 * MultilayerPerceptron mlp = new MultilayerPerceptron(config, testFeatures, testLabels);
 * mlp.train(features, labels);
 *
 * double accuracy = mlp.evaluate(testFeatures, testLabels);
 * double[] prediction = mlp.predict(singleInput);
 * }</pre>
*
* <p>Note: This implementation logs informative messages, such as accuracy per epoch, using SLF4J logging.</p>
*
* @see de.edux.api.Classifier
* @see de.edux.ml.nn.Neuron
* @see de.edux.ml.nn.config.NetworkConfiguration
* @see de.edux.functions.activation.ActivationFunction
*/
public class MultilayerPerceptron implements Classifier {
private static final Logger LOG = LoggerFactory.getLogger(MultilayerPerceptron.class);

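The early-stopping behavior the Javadoc describes can be sketched as follows. This is an illustrative pattern, not the MultilayerPerceptron internals; the Model interface, the patience parameter, and the save/restore hooks are hypothetical.

// Hypothetical early-stopping loop: track the best accuracy, snapshot the best
// model, and stop when no improvement is seen for `patience` epochs.
public class EarlyStoppingSketch {

    interface Model {
        void trainOneEpoch();       // one pass over the training data
        double evaluateAccuracy();  // accuracy on held-out test data
        void saveBestState();       // snapshot current weights and biases
        void restoreBestState();    // roll back to the best snapshot
    }

    static void train(Model model, int maxEpochs, int patience) {
        double bestAccuracy = Double.NEGATIVE_INFINITY;
        int epochsWithoutImprovement = 0;

        for (int epoch = 0; epoch < maxEpochs; epoch++) {
            model.trainOneEpoch();
            double accuracy = model.evaluateAccuracy();

            if (accuracy > bestAccuracy) {
                bestAccuracy = accuracy;
                model.saveBestState();        // remember the best model so far
                epochsWithoutImprovement = 0;
            } else if (++epochsWithoutImprovement >= patience) {
                break;                        // no progress: stop early
            }
        }
        model.restoreBestState();             // return the best model, not the last
    }
}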
4 changes: 1 addition & 3 deletions lib/src/main/java/de/edux/ml/randomforest/Sample.java
@@ -1,5 +1,3 @@
package de.edux.ml.randomforest;

-public record Sample(double[][] featureSamples, double[][] labelSamples) {
-
-}
+public record Sample(double[][] featureSamples, double[][] labelSamples) {}
23 changes: 21 additions & 2 deletions lib/src/main/java/de/edux/ml/svm/SupportVectorMachine.java
@@ -7,10 +7,29 @@
import java.util.*;
import java.util.stream.Collectors;

/**
 * The {@code SupportVectorMachine} class implements a Support Vector Machine (SVM) classifier, using the one-vs-one strategy for multi-class classification.
 * It trains a separate binary classifier for each pair of classes in the training set, using the provided kernel function and regularization parameter C.
 * During prediction, each binary classifier casts a vote, and the final predicted class is the one that receives the most votes.
*
* <p>Example usage:</p>
* <pre>{@code
* SVMKernel kernel = ... ; // Define an appropriate SVM kernel function
* double c = ... ; // Define an appropriate regularization parameter
*
* SupportVectorMachine svm = new SupportVectorMachine(kernel, c);
* svm.train(trainingFeatures, trainingLabels);
*
* double[] prediction = svm.predict(inputFeatures);
* double accuracy = svm.evaluate(testFeatures, testLabels);
* }</pre>
*
* <p>Note: Label arrays are expected to be in one-hot encoding format and will be internally converted to single label format for training.</p>
*
* @see de.edux.api.Classifier
*/
public class SupportVectorMachine implements Classifier {

private static final Logger LOG = LoggerFactory.getLogger(SupportVectorMachine.class);

private final SVMKernel kernel;
private final double c;
private final Map<String, SVMModel> models;
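The one-vs-one voting and the one-hot conversion mentioned in the note can be sketched as below. The BinaryClassifier interface and the model map are hypothetical; only the voting and argmax logic mirror the documented behavior.

// Hypothetical one-vs-one voting: one binary model per pair of classes,
// majority vote across all models wins. Not the edux source.
import java.util.HashMap;
import java.util.Map;

public class OneVsOneVotingSketch {

    // A binary model trained on one pair of classes; returns the class id it predicts.
    interface BinaryClassifier {
        int predict(double[] features);
    }

    // Majority vote across all pairwise models, as the Javadoc describes.
    static int predict(Map<String, BinaryClassifier> models, double[] features) {
        Map<Integer, Integer> votes = new HashMap<>();
        for (BinaryClassifier model : models.values()) {
            votes.merge(model.predict(features), 1, Integer::sum);
        }
        return votes.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .orElseThrow()
                .getKey();
    }

    // One-hot labels reduce to a single class id via argmax, matching the
    // conversion the note mentions.
    static int oneHotToLabel(double[] oneHot) {
        int best = 0;
        for (int i = 1; i < oneHot.length; i++) {
            if (oneHot[i] > oneHot[best]) best = i;
        }
        return best;
    }
}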
