Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Code refactoring duplicated code #107

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions examples/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@

GCC=gcc
CFLAGS=-I ../src/include
LDFLAGS=-L ../src/
#LDFLAGS=-L ../src/
LDFLAGS=-L ../build/src/

TARGETS = xor_train xor_test xor_test_fixed simple_train \
steepness_train simple_test robot mushroom cascade_train \
scaling_test scaling_train nn-benchmark parallel_train

TARGETS = xor_train xor_test xor_test_fixed simple_train steepness_train simple_test robot mushroom cascade_train scaling_test scaling_train
DEBUG_TARGETS = xor_train_debug xor_test_debug xor_test_fixed_debug cascade_train_debug

all: $(TARGETS)
Expand Down
144 changes: 144 additions & 0 deletions examples/nn-benchmark.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/* This benchmark generates a random set of inputs and outputs
* on a large network and tests the backpropagation speed in order
* for the net to arrive at a certain level of error. */

#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>

#include "fann.h"
#include "parallel_fann.h"

//#define ONLY_TEST

#define NUM_INPUTS 60 //120 //300
#define NUM_OUTPUTS 10 //40 //100

int math_random(int low, int up) {
fann_type r = rand() * (1.0 / (RAND_MAX + 1.0));
r *= (up - low) + 1.0;
return (int)r+low;
}

void gen_dataset(fann_type **inputs, fann_type **outputs, int setsize) {
*inputs = calloc(1, sizeof(fann_type)*setsize*NUM_INPUTS);
*outputs = calloc(1, sizeof(fann_type)*setsize*NUM_OUTPUTS);
int ilen = NUM_INPUTS;
int olen = NUM_OUTPUTS;
int olen_1 = olen - 1;

fann_type *in = *inputs;
fann_type *out = *outputs;
for (int j = 0; j < setsize; j++) {
for (int k = 0; k < ilen; k++) in[k] = rand() & 1;
//int r = rand() & olen_1;
int r = math_random(0, olen_1);
out[r] = 1;
//printf("%d : %d\n", j, r);
//for (int k = 0; k < olen; k++) {
// out[k] = (k == r) ? 1 : 0;
//}
in+= ilen;
out+= olen;
}
}

int main(int argc, char *argv[]) {
const char fn_net[] = "benchmark_float.net";
int setsize = 1000;
unsigned int i = 0, ilen;
unsigned int num_threads = 1;
if(argc == 2)
num_threads = atoi(argv[1]);
#ifndef ONLY_TEST
fann_type *inputs, *outputs;
gen_dataset(&inputs, &outputs, setsize);
#endif
fann_type *test_inputs, *test_outputs;
gen_dataset(&test_inputs, &test_outputs, setsize);

struct fann_train_data *test_data;

#ifndef ONLY_TEST
const unsigned int num_layers = 3;
const unsigned int num_neurons_hidden = NUM_INPUTS*2;
const float desired_error = (const float) 0.0001;
float desired_error_reached;
const unsigned int max_epochs = 3000;
const unsigned int epochs_between_reports = 10;
struct fann *ann;
struct fann_train_data *train_data;

printf("Creating network.\n");

train_data = fann_create_train_array(setsize, NUM_INPUTS, inputs, NUM_OUTPUTS, outputs);

ann = fann_create_standard(num_layers,
train_data->num_input, num_neurons_hidden, train_data->num_output);

printf("Training network.\n");

fann_set_activation_function_hidden(ann, FANN_SIGMOID_SYMMETRIC);
fann_set_activation_function_output(ann, FANN_SIGMOID);
//fann_set_training_algorithm(ann, FANN_TRAIN_INCREMENTAL);
fann_set_training_algorithm(ann, FANN_TRAIN_RPROP);
fann_set_learning_rate(ann, 0.5f);
fann_randomize_weights(ann, -2.0f, 2.0f);

/*fann_set_training_algorithm(ann, FANN_TRAIN_INCREMENTAL); */

//fann_train_on_data(ann, train_data, max_epochs, epochs_between_reports, desired_error);


long before = fann_mstime();
for(i = 1; i <= max_epochs; i++)
{
long start = fann_mstime();
double error = (num_threads > 1)
? fann_train_epoch_irpropm_parallel(ann, train_data, num_threads)
: fann_train_epoch(ann, train_data);
long elapsed = fann_mstime() - start;
printf("Epochs %8d. Current error: %.10f :: %ld\n", i, error, elapsed);
desired_error_reached = fann_desired_error_reached(ann, desired_error);
if(desired_error_reached == 0)
break;
}
printf("Time spent %ld ms\n", fann_mstime()-before);

#else
struct fann *ann = fann_create_from_file(fn_net);
#endif

test_data = fann_create_train_array(setsize, NUM_INPUTS, test_inputs, NUM_OUTPUTS, test_outputs);
ilen = fann_length_train_data(test_data);
printf("Testing network. %d\n", ilen);


fann_reset_MSE(ann);
for(i = 0; i < ilen; i++)
{
fann_test(ann, test_data->input[i], test_data->output[i]);
}

printf("MSE error on test data: %f\n", fann_get_MSE(ann));

#ifndef ONLY_TEST
printf("Saving network.\n");

fann_save(ann, fn_net);
#endif

printf("Cleaning up.\n");
fann_destroy(ann);

#ifndef ONLY_TEST
fann_destroy_train(train_data);
free(inputs);
free(outputs);
#endif
fann_destroy_train(test_data);
free(test_inputs);
free(test_outputs);

return 0;
}
9 changes: 5 additions & 4 deletions examples/parallel_train.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ int main(int argc, const char* argv[])
unsigned int num_threads = 1;
struct fann_train_data *data;
struct fann *ann;
long before;
long before, totaltime;
float error;
unsigned int i;

Expand All @@ -39,13 +39,14 @@ int main(int argc, const char* argv[])
fann_set_activation_function_hidden(ann, FANN_SIGMOID_SYMMETRIC);
fann_set_activation_function_output(ann, FANN_SIGMOID);

before = GetTickCount();
before = fann_mstime();
for(i = 1; i <= max_epochs; i++)
{
error = num_threads > 1 ? fann_train_epoch_irpropm_parallel(ann, data, num_threads) : fann_train_epoch(ann, data);
printf("Epochs %8d. Current error: %.10f\n", i, error);
totaltime = fann_mstime()-before;
printf("Epochs %8d. Current error: %.10f :: %ld ms by epoch\n", i, error, i ? totaltime/i : 0);
}
printf("ticks %d", GetTickCount()-before);
printf("Time spent %ld ms\n", totaltime);

fann_destroy(ann);
fann_destroy_train(data);
Expand Down
2 changes: 2 additions & 0 deletions examples/run-it
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
export LD_LIBRARY_PATH=$PWD/../build/src:$LD_LIBRARY_PATH
$*
Loading