Skip to content

btbujiangjun/abcdl

Folders and files

Name
Last commit message
Last commit date

Latest commit

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

abcdl

A lightweight C++ deep learning framework for ABC, including DNN, CNN, and RNN.

DNN example

1. Configure layers

std::vector<abcdl::dnn::Layer*> layers;
layers.push_back(new abcdl::dnn::InputLayer(784));
layers.push_back(new abcdl::dnn::FullConnLayer(784, 30, new abcdl::framework::SigmoidActivateFunc()));
layers.push_back(new abcdl::dnn::OutputLayer(30, 10, new abcdl::framework::SigmoidActivateFunc(), new abcdl::framework::CrossEntropyCost()));

2. Initialize network

abcdl::dnn::DNN dnn;
dnn.set_layers(layers);

3. Load training data

abcdl::utils::MnistHelper helper;

abcdl::algebra::Mat train_data;
helper.read_image("data/mnist/train-images-idx3-ubyte", &train_data, 60000);

abcdl::algebra::Mat train_label;
helper.read_vec_label("data/mnist/train-labels-idx1-ubyte", &train_label, 10000);

4. Train network

dnn.train(train_data, train_label);

5. Predict

abcdl::algebra::Mat result;
abcdl::algebra::Mat predict_data;
helper.read_image("data/mnist/t10k-images-idx3-ubyte", &predict_data, 1);
dnn.predict(result, predict_data);

6. Serialize model

const std::string path = "data/dnn.model";
dnn.write_model(path);

7. Deserialize model

dnn.load_model(path);

CNN example

1. Configure layers

std::vector<abcdl::cnn::Layer*> layers;
layers.push_back(new abcdl::cnn::InputLayer(28, 28));
layers.push_back(new abcdl::cnn::ConvolutionLayer(3, 1, 5, new abcdl::framework::SigmoidActivateFunc()));
layers.push_back(new abcdl::cnn::SubSamplingLayer(2, new abcdl::framework::MeanPooling()));
layers.push_back(new abcdl::cnn::ConvolutionLayer(3, 1, 5, new abcdl::framework::SigmoidActivateFunc()));
layers.push_back(new abcdl::cnn::OutputLayer(10, new abcdl::framework::SigmoidActivateFunc(), new abcdl::framework::CrossEntropyCost()));

2. Initialize network

abcdl::cnn::CNN cnn;
cnn.set_layers(layers);

3. Load training data

abcdl::utils::MnistHelper helper;

abcdl::algebra::Mat train_data;
helper.read_image("data/mnist/train-images-idx3-ubyte", &train_data, 60000);

abcdl::algebra::Mat train_label;
helper.read_vec_label("data/mnist/train-labels-idx1-ubyte", &train_label, 10000);

4. Train network

cnn.train(train_data, train_label);

5. Predict

abcdl::algebra::Mat result;
abcdl::algebra::Mat predict_data;
helper.read_image("data/mnist/t10k-images-idx3-ubyte", &predict_data, 1);
cnn.predict(result, predict_data);

About

A lightweight C++ deep learning framework for ABC

Topics

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published