A lightweight C++ deep learning framework for ABC, including DNN, CNN, and RNN.
// DNN example: build a fully connected network for MNIST, train it, predict
// one test image, then save and reload the model.
//
// Layer stack: 784-wide input -> fully connected 784->30 (sigmoid)
// -> output 30->10 (sigmoid, cross-entropy cost).
// 784 = 28 * 28 MNIST pixels; 10 = digit classes.
// NOTE(review): the layers are heap-allocated with raw `new` and their
// ownership is not shown here. Confirm that DNN frees them; otherwise they leak.
std::vector<abcdl::dnn::Layer*> layers;
layers.push_back(new abcdl::dnn::InputLayer(784));
layers.push_back(new abcdl::dnn::FullConnLayer(784, 30, new abcdl::framework::SigmoidActivateFunc()));
layers.push_back(new abcdl::dnn::OutputLayer(30, 10, new abcdl::framework::SigmoidActivateFunc(), new abcdl::framework::CrossEntropyCost()));
abcdl::dnn::DNN dnn;
dnn.set_layers(layers);

abcdl::utils::MnistHelper helper;
// The third argument appears to be the number of samples to read.
// Images and labels must use the same count so that every training row has
// a label. The MNIST training set has 60000 of each.
abcdl::algebra::Mat train_data;
helper.read_image("data/mnist/train-images-idx3-ubyte", &train_data, 60000);
abcdl::algebra::Mat train_label;
helper.read_vec_label("data/mnist/train-labels-idx1-ubyte", &train_label, 60000);
dnn.train(train_data, train_label);

// Predict a single image from the MNIST test set.
abcdl::algebra::Mat result;
abcdl::algebra::Mat predict_data;
helper.read_image("data/mnist/t10k-images-idx3-ubyte", &predict_data, 1);
dnn.predict(result, predict_data);

// Persist the trained model, then load it back from the same path.
const std::string path = "data/dnn.model";
dnn.write_model(path);
dnn.load_model(path);
// CNN example: build a convolutional network for MNIST, train it, and predict
// one test image.
//
// Layer stack: 28x28 input -> convolution (sigmoid) -> mean-pooling
// subsampling -> convolution (sigmoid) -> 10-way output (sigmoid,
// cross-entropy cost).
// NOTE(review): the meaning of ConvolutionLayer's (3, 1, 5) arguments is not
// shown here. If the second argument is the input channel count, the second
// convolution may need 3 (the first layer's output maps) rather than 1.
// Confirm against the ConvolutionLayer constructor.
// NOTE(review): the layers are heap-allocated with raw `new` and their
// ownership is not shown here. Confirm that CNN frees them; otherwise they leak.
std::vector<abcdl::cnn::Layer*> layers;
layers.push_back(new abcdl::cnn::InputLayer(28, 28));
layers.push_back(new abcdl::cnn::ConvolutionLayer(3, 1, 5, new abcdl::framework::SigmoidActivateFunc()));
layers.push_back(new abcdl::cnn::SubSamplingLayer(2, new abcdl::framework::MeanPooling()));
layers.push_back(new abcdl::cnn::ConvolutionLayer(3, 1, 5, new abcdl::framework::SigmoidActivateFunc()));
layers.push_back(new abcdl::cnn::OutputLayer(10, new abcdl::framework::SigmoidActivateFunc(), new abcdl::framework::CrossEntropyCost()));
abcdl::cnn::CNN cnn;
cnn.set_layers(layers);

abcdl::utils::MnistHelper helper;
// The third argument appears to be the number of samples to read.
// Images and labels must use the same count so that every training row has
// a label. The MNIST training set has 60000 of each.
abcdl::algebra::Mat train_data;
helper.read_image("data/mnist/train-images-idx3-ubyte", &train_data, 60000);
abcdl::algebra::Mat train_label;
helper.read_vec_label("data/mnist/train-labels-idx1-ubyte", &train_label, 60000);
cnn.train(train_data, train_label);

// Predict a single image from the MNIST test set.
abcdl::algebra::Mat result;
abcdl::algebra::Mat predict_data;
helper.read_image("data/mnist/t10k-images-idx3-ubyte", &predict_data, 1);
cnn.predict(result, predict_data);