From 6aa5399a62be4f8a803676c644abfaaf69a2fb1d Mon Sep 17 00:00:00 2001
From: Omar Shrit
Date: Tue, 28 Nov 2023 19:46:29 +0100
Subject: [PATCH 1/2] Start working on bandicoot example

Signed-off-by: Omar Shrit
---
 mnist_simple_coot/Makefile              |  43 ++++++
 mnist_simple_coot/mnist_simple_coot.cpp | 186 ++++++++++++++++++++++++
 2 files changed, 229 insertions(+)
 create mode 100644 mnist_simple_coot/Makefile
 create mode 100644 mnist_simple_coot/mnist_simple_coot.cpp

diff --git a/mnist_simple_coot/Makefile b/mnist_simple_coot/Makefile
new file mode 100644
index 00000000..8a8f8e09
--- /dev/null
+++ b/mnist_simple_coot/Makefile
@@ -0,0 +1,43 @@
+# This is a simple Makefile used to build the example source code.
+# This example might require some modifications in order to work correctly on
+# your system.
+# This example trains an mlpack neural network on the GPU via OpenCL or CUDA,
+# using the bandicoot library.
+
+TARGET := mnist_simple_coot
+SRC := mnist_simple_coot.cpp
+LIBS_NAME := bandicoot
+
+CXX := g++
+CXXFLAGS += -std=c++14 -Wall -Wextra -O3 -DNDEBUG -fopenmp
+# Use these CXXFLAGS instead if you want to compile with debugging symbols and
+# without optimizations.
+# CXXFLAGS += -std=c++14 -Wall -Wextra -g -O0
+
+LDFLAGS += -fopenmp
+# Add header directories for any includes that aren't on the
+# default compiler search path.
+INCLFLAGS := -I .
+# If you have mlpack or ensmallen installed somewhere nonstandard, uncomment and
+# update the lines below.
+INCLFLAGS += -I/opt/cuda/targets/x86_64-linux/include
+# INCLFLAGS += -I/path/to/ensmallen/include/
+CXXFLAGS += $(INCLFLAGS)
+
+OBJS := $(SRC:.cpp=.o)
+LIBS := $(addprefix -l,$(LIBS_NAME))
+CLEAN_LIST := $(TARGET) $(OBJS)
+
+# default rule
+default: all
+
+$(TARGET): $(OBJS)
+	$(CXX) $(OBJS) -o $(TARGET) $(LDFLAGS) $(LIBS)
+
+.PHONY: all
+all: $(TARGET)
+
+.PHONY: clean
+clean:
+	@echo CLEAN $(CLEAN_LIST)
+	@rm -f $(CLEAN_LIST)

diff --git a/mnist_simple_coot/mnist_simple_coot.cpp b/mnist_simple_coot/mnist_simple_coot.cpp
new file mode 100644
index 00000000..c9708e1d
--- /dev/null
+++ b/mnist_simple_coot/mnist_simple_coot.cpp
@@ -0,0 +1,186 @@
+/**
+ * An example of using a feed forward neural network (FFN) for solving the
+ * Digit Recognizer problem from the Kaggle website, trained on the GPU
+ * through the bandicoot (coot) matrix library.
+ *
+ * The full description of the problem as well as the datasets for training
+ * and testing are available here: https://www.kaggle.com/c/digit-recognizer
+ *
+ * mlpack is free software; you may redistribute it and/or modify it under the
+ * terms of the 3-clause BSD license. You should have received a copy of the
+ * 3-clause BSD license along with mlpack. If not, see
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
+ *
+ * @author Eugene Freyman
+ * @author Omar Shrit
+ */
+#define MLPACK_ENABLE_ANN_SERIALIZATION
+#include <mlpack.hpp>
+#include <bandicoot>
+
+#if ((ENS_VERSION_MAJOR < 2) || \
+    ((ENS_VERSION_MAJOR == 2) && (ENS_VERSION_MINOR < 13)))
+  #error "need ensmallen version 2.13.0 or later"
+#endif
+
+using namespace mlpack;
+using namespace std;
+
+// Returns, for every column of predOut (one column per data point), the row
+// index of the largest value, i.e. the predicted class label.
+coot::Row<size_t> getLabels(coot::mat predOut)
+{
+  coot::Row<size_t> predLabels(predOut.n_cols);
+  for (coot::uword i = 0; i < predOut.n_cols; ++i)
+  {
+    predLabels(i) = predOut.col(i).index_max();
+  }
+  return predLabels;
+}
+
+int main()
+{
+  // The dataset is randomly split into validation and training parts with the
+  // following ratio.
+  constexpr double RATIO = 0.1;
+  // The number of neurons in the first hidden layer.
+  constexpr int H1 = 200;
+  // The number of neurons in the second hidden layer.
+  constexpr int H2 = 100;
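+  // Note that the input dimensionality is fixed by the data: each MNIST image
+  // is 28 x 28 pixels, i.e. 784 input features, and there are 10 output
+  // classes (the digits 0-9).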
+  // Step size of the optimizer.
+  const double STEP_SIZE = 5e-3;
+  // Number of data points in each iteration of SGD.
+  const size_t BATCH_SIZE = 64;
+  // Allow up to 50 epochs, unless we are stopped early by EarlyStopAtMinLoss.
+  const int EPOCHS = 50;
+
+  // The labeled dataset that contains the training data is loaded from a CSV
+  // file; rows represent features and columns represent data points.
+  arma::mat dataset;
+  data::Load("../data/mnist_train.csv", dataset, true);
+
+  // The original Kaggle CSV file has a header, so it is necessary to get rid
+  // of this row; in the Armadillo representation it is the first column.
+  arma::mat headerLessDataset =
+      dataset.submat(0, 1, dataset.n_rows - 1, dataset.n_cols - 1);
+
+  // Splitting the dataset into training and validation parts.
+  arma::mat train, valid;
+  data::Split(headerLessDataset, train, valid, RATIO);
+
+  // Getting the training and validation datasets with features only,
+  // normalising the pixel values to [0, 1], and converting them to bandicoot
+  // (GPU) matrices.
+  const coot::mat trainX = coot::conv_to<coot::mat>::from(
+      train.submat(1, 0, train.n_rows - 1, train.n_cols - 1) / 255.0);
+  const coot::mat validX = coot::conv_to<coot::mat>::from(
+      valid.submat(1, 0, valid.n_rows - 1, valid.n_cols - 1) / 255.0);
+
+  // Labels should specify the class of a data point and be in the interval
+  // [0, numClasses).
+
+  // Creating labels for the training and validation datasets.
+  const coot::mat trainY = coot::conv_to<coot::mat>::from(train.row(0));
+  const coot::mat validY = coot::conv_to<coot::mat>::from(valid.row(0));
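+
+  // The network built below is a small multilayer perceptron:
+  //   784 -> Linear(H1) -> ReLU -> Linear(H2) -> ReLU -> Dropout -> Linear(10)
+  //   -> LogSoftMax.
+  // LogSoftMax combined with the NegativeLogLikelihood output layer gives the
+  // usual cross-entropy objective for classification.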
+
+  // Specifying the NN model. NegativeLogLikelihood is the output layer that
+  // is used for classification problems. GlorotInitialization means that the
+  // initial weights are set using the Glorot (Xavier) initialization scheme.
+  FFN<NegativeLogLikelihoodType<coot::mat>, GlorotInitialization> model;
+  // This intermediate layer connects the input data with the first ReLU
+  // layer; the parameter specifies the number of neurons in the next layer.
+  model.Add<Linear>(H1);
+  // The first ReLU layer.
+  model.Add<ReLU>();
+  // Intermediate layer between the ReLU layers.
+  model.Add<Linear>(H2);
+  // The second ReLU layer.
+  model.Add<ReLU>();
+  // Dropout layer for regularization. The parameter is the probability of
+  // setting a specific value to 0.
+  model.Add<Dropout>(0.2);
+  // Final linear layer, with one output per class.
+  model.Add<Linear>(10);
+  // LogSoftMax layer is used together with NegativeLogLikelihood for mapping
+  // output values to log of probabilities of being a specific class.
+  model.Add<LogSoftMax>();
+
+  cout << "Start training ..." << endl;
+
+  // Set parameters for the Adam optimizer.
+  ens::Adam optimizer(
+      STEP_SIZE,  // Step size of the optimizer.
+      BATCH_SIZE, // Batch size. Number of data points that are used in each
+                  // iteration.
+      0.9,        // Exponential decay rate for the first moment estimates.
+      0.999,      // Exponential decay rate for the second moment estimates.
+      1e-8,       // Value used to initialise the mean squared gradient parameter.
+      EPOCHS * trainX.n_cols, // Max number of iterations.
+      1e-8,       // Tolerance.
+      true);
+
+  // Declare callback to store best training weights.
+  ens::StoreBestCoordinates<coot::mat> bestCoordinates;
+
+  // Train the neural network. If this is the first iteration, weights are
+  // random; otherwise the current values are used as the starting point.
+  model.Train(trainX,
+              trainY,
+              optimizer,
+              ens::PrintLoss(),
+              ens::ProgressBar(),
+              // Stop the training using Early Stop at min loss.
+              ens::EarlyStopAtMinLoss(
+                  [&](const coot::mat& /* param */)
+                  {
+                    double validationLoss = model.Evaluate(validX, validY);
+                    cout << "Validation loss: " << validationLoss << "."
+                         << endl;
+                    return validationLoss;
+                  }),
+              // Store best coordinates (neural network weights).
+              bestCoordinates);
+
+  // Save the best training weights into the model.
+  model.Parameters() = bestCoordinates.BestCoordinates();
+
+  coot::mat predOut;
+  // Getting predictions on the training data points.
+  model.Predict(trainX, predOut);
+  // Calculating accuracy on the training data points.
+  coot::Row<size_t> predLabels = getLabels(predOut);
+  double trainAccuracy =
+      coot::accu(predLabels == trainY) / (double) trainY.n_elem * 100;
+  // Getting predictions on the validation data points.
+  model.Predict(validX, predOut);
+  // Calculating accuracy on the validation data points.
+  predLabels = getLabels(predOut);
+  double validAccuracy =
+      coot::accu(predLabels == validY) / (double) validY.n_elem * 100;
+
+  cout << "Accuracy: train = " << trainAccuracy << "%,"
+       << "\t valid = " << validAccuracy << "%" << endl;
+
+  data::Save("model.bin", "model", model, false);
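+
+  // A minimal sketch of how the serialized model could be loaded back in
+  // another program (assuming the same FFN type is declared there):
+  //
+  //   FFN<NegativeLogLikelihoodType<coot::mat>, GlorotInitialization> model2;
+  //   data::Load("model.bin", "model", model2, true);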
+
+  // Loading the test dataset (the one whose predicted labels
+  // should be sent to the Kaggle website).
+  //data::Load("../data/mnist_test.csv", dataset, true);
+  //coot::mat testY = dataset.row(0);
+  //dataset.shed_row(0); // Strip labels before predicting.
+  //dataset /= 255.0; // Apply the same normalization as to the training data.
+
+  //cout << "Predicting on test set..." << endl;
+  //coot::mat testPredOut;
+  //// Getting predictions on test data points.
+  //model.Predict(dataset, testPredOut);
+  //// Generating labels for the test dataset.
+  //coot::Row<size_t> testPred = getLabels(testPredOut);
+
+  //double testAccuracy = coot::accu(testPred == testY) /
+  //    (double) testY.n_elem * 100;
+  //cout << "Accuracy: test = " << testAccuracy << "%" << endl;
+
+  //cout << "Saving predicted labels to \"results.csv\" ..." << endl;
+  //testPred.save("results.csv", coot::csv_ascii);
+
+  //cout << "Neural network model is saved to \"model.bin\"" << endl;
+  cout << "Finished" << endl;
+}

From 8c519c1b2fe56723d95cbe2361ba5137477c27ff Mon Sep 17 00:00:00 2001
From: Omar Shrit
Date: Fri, 26 Jan 2024 12:29:28 +0100
Subject: [PATCH 2/2] Just committing the modifications so I know where I am

Signed-off-by: Omar Shrit
---
 mnist_simple_coot/Makefile              |  3 ++-
 mnist_simple_coot/mnist_simple_coot.cpp | 24 +++++++++++++-----------
 2 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/mnist_simple_coot/Makefile b/mnist_simple_coot/Makefile
index 8a8f8e09..2de4b93c 100644
--- a/mnist_simple_coot/Makefile
+++ b/mnist_simple_coot/Makefile
@@ -21,7 +21,8 @@ INCLFLAGS := -I .
 # If you have mlpack or ensmallen installed somewhere nonstandard, uncomment and
 # update the lines below.
 INCLFLAGS += -I/opt/cuda/targets/x86_64-linux/include
-# INCLFLAGS += -I/path/to/ensmallen/include/
+INCLFLAGS += -I/meta/mlpack/src
+#INCLFLAGS += -I/meta/m
 CXXFLAGS += $(INCLFLAGS)
 
 OBJS := $(SRC:.cpp=.o)

diff --git a/mnist_simple_coot/mnist_simple_coot.cpp b/mnist_simple_coot/mnist_simple_coot.cpp
index c9708e1d..f34f7b94 100644
--- a/mnist_simple_coot/mnist_simple_coot.cpp
+++ b/mnist_simple_coot/mnist_simple_coot.cpp
@@ -14,8 +14,10 @@
  * @author Omar Shrit
  */
 #define MLPACK_ENABLE_ANN_SERIALIZATION
-#include <mlpack.hpp>
+#define MLPACK_HAS_COOT
+
 #include <bandicoot>
+#include <mlpack.hpp>
 
 #if ((ENS_VERSION_MAJOR < 2) || \
     ((ENS_VERSION_MAJOR == 2) && (ENS_VERSION_MINOR < 13)))
   #error "need ensmallen version 2.13.0 or later"
 #endif
@@ -30,7 +32,7 @@ coot::Row<size_t> getLabels(coot::mat predOut)
   coot::Row<size_t> predLabels(predOut.n_cols);
   for (coot::uword i = 0; i < predOut.n_cols; ++i)
   {
-    predLabels(i) = predOut.col(i).index_max();
+    // predLabels(i) = predOut.col(i).index_max();
   }
   return predLabels;
 }
@@ -82,25 +84,25 @@ int main()
   // Specifying the NN model. NegativeLogLikelihood is the output layer that
   // is used for classification problems. GlorotInitialization means that the
   // initial weights are set using the Glorot (Xavier) initialization scheme.
-  FFN<NegativeLogLikelihoodType<coot::mat>, GlorotInitialization> model;
+  FFN<NegativeLogLikelihoodType<coot::mat>, GlorotInitialization, coot::mat> model;
   // This intermediate layer connects the input data with the first ReLU
   // layer; the parameter specifies the number of neurons in the next layer.
-  model.Add<Linear>(H1);
+  model.Add<LinearType<coot::mat>>(H1);
   // The first ReLU layer.
-  model.Add<ReLU>();
+  model.Add<ReLUType<coot::mat>>();
   // Intermediate layer between the ReLU layers.
-  model.Add<Linear>(H2);
+  model.Add<LinearType<coot::mat>>(H2);
   // The second ReLU layer.
-  model.Add<ReLU>();
+  model.Add<ReLUType<coot::mat>>();
   // Dropout layer for regularization. The parameter is the probability of
   // setting a specific value to 0.
-  model.Add<Dropout>(0.2);
+  model.Add<DropoutType<coot::mat>>(0.2);
   // Final linear layer, with one output per class.
-  model.Add<Linear>(10);
+  model.Add<LinearType<coot::mat>>(10);
   // LogSoftMax layer is used together with NegativeLogLikelihood for mapping
   // output values to log of probabilities of being a specific class.
-  model.Add<LogSoftMax>();
+  model.Add<LogSoftMaxType<coot::mat>>();
 
   cout << "Start training ..." << endl;
 
@@ -127,7 +129,7 @@ int main()
               ens::PrintLoss(),
               ens::ProgressBar(),
               // Stop the training using Early Stop at min loss.
-              ens::EarlyStopAtMinLoss(
+              ens::EarlyStopAtMinLossType<coot::mat>(
                   [&](const coot::mat& /* param */)
                   {
                     double validationLoss = model.Evaluate(validX, validY);