Skip to content

Commit

Permalink
Leaky ReLUs with test case
Browse files Browse the repository at this point in the history
Original patch by @echoline (libfann#105 (comment)) plus a test case.
  • Loading branch information
DrDub committed Dec 6, 2023
1 parent 8409b42 commit 44bfcee
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 8 deletions.
17 changes: 12 additions & 5 deletions src/fann.c
Original file line number Diff line number Diff line change
Expand Up @@ -683,13 +683,20 @@ FANN_EXTERNAL fann_type *FANN_API fann_run(struct fann *ann, fann_type *input) {
neuron_it->value = neuron_sum;
break;
case FANN_LINEAR_PIECE:
neuron_it->value = (fann_type)(
(neuron_sum < 0) ? 0 : (neuron_sum > multiplier) ? multiplier : neuron_sum);
neuron_it->value = (fann_type)((neuron_sum < 0) ? 0
: (neuron_sum > multiplier) ? multiplier
: neuron_sum);
break;
case FANN_LINEAR_PIECE_SYMMETRIC:
neuron_it->value = (fann_type)((neuron_sum < -multiplier)
? -multiplier
: (neuron_sum > multiplier) ? multiplier : neuron_sum);
neuron_it->value = (fann_type)((neuron_sum < -multiplier) ? -multiplier
: (neuron_sum > multiplier) ? multiplier
: neuron_sum);
break;
case FANN_LINEAR_PIECE_LEAKY:
neuron_it->value = (fann_type)((neuron_sum < 0) ? 0.01 * neuron_sum : neuron_sum);
break;
case FANN_LINEAR_PIECE_RECT:
neuron_it->value = (fann_type)((neuron_sum < 0) ? 0 : neuron_sum);
break;
case FANN_ELLIOT:
case FANN_ELLIOT_SYMMETRIC:
Expand Down
2 changes: 2 additions & 0 deletions src/fann_cascade.c
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,8 @@ fann_type fann_train_candidates_epoch(struct fann *ann, struct fann_train_data *
case FANN_GAUSSIAN_STEPWISE:
case FANN_ELLIOT:
case FANN_LINEAR_PIECE:
case FANN_LINEAR_PIECE_LEAKY:
case FANN_LINEAR_PIECE_RECT:
case FANN_SIN:
case FANN_COS:
break;
Expand Down
6 changes: 6 additions & 0 deletions src/fann_train.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ fann_type fann_activation_derived(unsigned int activation_function, fann_type st
case FANN_LINEAR_PIECE:
case FANN_LINEAR_PIECE_SYMMETRIC:
return (fann_type)fann_linear_derive(steepness, value);
case FANN_LINEAR_PIECE_LEAKY:
return (fann_type)((value < 0) ? steepness * 0.01 : steepness);
case FANN_LINEAR_PIECE_RECT:
return (fann_type)((value < 0) ? 0 : steepness);
case FANN_SIGMOID:
case FANN_SIGMOID_STEPWISE:
value = fann_clip(value, 0.01f, 0.99f);
Expand Down Expand Up @@ -125,6 +129,8 @@ fann_type fann_update_MSE(struct fann *ann, struct fann_neuron *neuron, fann_typ
case FANN_GAUSSIAN_STEPWISE:
case FANN_ELLIOT:
case FANN_LINEAR_PIECE:
case FANN_LINEAR_PIECE_LEAKY:
case FANN_LINEAR_PIECE_RECT:
case FANN_SIN:
case FANN_COS:
break;
Expand Down
6 changes: 6 additions & 0 deletions src/include/fann_activation.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ __doublefann_h__ is not defined
case FANN_GAUSSIAN_STEPWISE: \
result = 0; \
break; \
case FANN_LINEAR_PIECE_LEAKY: \
result = (fann_type)((value < 0) ? value * 0.01 : value); \
break; \
case FANN_LINEAR_PIECE_RECT: \
result = (fann_type)((value < 0) ? 0 : value); \
break; \
}

#endif
18 changes: 16 additions & 2 deletions src/include/fann_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,16 @@ static char const *const FANN_TRAIN_NAMES[] = {"FANN_TRAIN_INCREMENTAL", "FANN_T
* y = cos(x*s)/2+0.5
* d = s*-sin(x*s)/2
FANN_LINEAR_PIECE_LEAKY - leaky ReLU
* span: -inf < y < inf
* y = x<0? 0.01*x: x
* d = x<0? 0.01: 1
FANN_LINEAR_PIECE_RECT - ReLU
* span: -inf < y < inf
* y = x<0? 0: x
* d = x<0? 0: 1
See also:
<fann_set_activation_function_layer>, <fann_set_activation_function_hidden>,
<fann_set_activation_function_output>, <fann_set_activation_steepness>,
Expand Down Expand Up @@ -223,7 +233,9 @@ enum fann_activationfunc_enum {
FANN_SIN_SYMMETRIC,
FANN_COS_SYMMETRIC,
FANN_SIN,
FANN_COS
FANN_COS,
FANN_LINEAR_PIECE_LEAKY,
FANN_LINEAR_PIECE_RECT
};

/* Constant: FANN_ACTIVATIONFUNC_NAMES
Expand Down Expand Up @@ -254,7 +266,9 @@ static char const *const FANN_ACTIVATIONFUNC_NAMES[] = {"FANN_LINEAR",
"FANN_SIN_SYMMETRIC",
"FANN_COS_SYMMETRIC",
"FANN_SIN",
"FANN_COS"};
"FANN_COS",
"FANN_LINEAR_PIECE_LEAKY",
"FANN_LINEAR_PIECE_RECT"};

/* Enum: fann_errorfunc_enum
Error function used during training.
Expand Down
14 changes: 13 additions & 1 deletion src/include/fann_data_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,16 @@ enum training_algorithm_enum {
* y = cos(x*s)
* d = s*-sin(x*s)
FANN_LINEAR_PIECE_LEAKY - leaky ReLU
* span: -inf < y < inf
* y = x<0? 0.01*x: x
* d = x<0? 0.01: 1
FANN_LINEAR_PIECE_RECT - ReLU
* span: -inf < y < inf
* y = x<0? 0: x
* d = x<0? 0: 1
See also:
<neural_net::set_activation_function_hidden>,
<neural_net::set_activation_function_output>
Expand All @@ -220,7 +230,9 @@ enum activation_function_enum {
LINEAR_PIECE,
LINEAR_PIECE_SYMMETRIC,
SIN_SYMMETRIC,
COS_SYMMETRIC
COS_SYMMETRIC,
LINEAR_PIECE_LEAKY,
LINEAR_PIECE_RECT
};

/* Enum: network_type_enum
Expand Down
13 changes: 13 additions & 0 deletions tests/fann_test_train.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ TEST_F(FannTestTrain, TrainOnDateSimpleXor) {
EXPECT_LT(net.test_data(data), 0.001);
}

TEST_F(FannTestTrain, TrainOnReLUSimpleXor) {
neural_net net(LAYER, 3, 2, 3, 1);

data.set_train_data(4, 2, xorInput, 1, xorOutput);
net.set_activation_function_hidden(FANN::LINEAR_PIECE_RECT);
net.set_activation_steepness_hidden(1.0);
net.train_on_data(data, 100, 100, 0.001);

EXPECT_LT(net.get_MSE(), 0.001);
EXPECT_LT(net.test_data(data), 0.001);
}

TEST_F(FannTestTrain, TrainSimpleIncrementalXor) {
neural_net net(LAYER, 3, 2, 3, 1);

Expand All @@ -41,3 +53,4 @@ TEST_F(FannTestTrain, TrainSimpleIncrementalXor) {

EXPECT_LT(net.get_MSE(), 0.01);
}

0 comments on commit 44bfcee

Please sign in to comment.