Skip to content

Commit

Permalink
MSE is Float32, not Float64
Browse files Browse the repository at this point in the history
  • Loading branch information
bararchy committed Sep 27, 2017
1 parent 3b94e07 commit 09fe537
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 14 deletions.
16 changes: 10 additions & 6 deletions spec/network_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ describe Fann::Network do

it "trains on single data" do
ann = Fann::Network::Standard.new(2, [2, 2], 1)
ann.randomzie_weights(0.0, 1.0)
3000.times do
ann.train_single([1.0, 0.0], [0.5])
end
Expand All @@ -33,10 +34,11 @@ describe Fann::Network do

it "trains and evaluate single data" do
ann = Fann::Network::Standard.new(2, [2], 1)
ann.randomzie_weights(0.0, 1.0)
3000.times do
ann.train_single([1.0, 0.1], [0.5])
ann.train_single([1.0, 0.0], [0.5])
end
result = ann.run([1.0, 0.1])
result = ann.run([1.0, 0.0])
ann.close
(result < [0.55] && result > [0.45]).should be_true
end
Expand All @@ -47,9 +49,10 @@ describe Fann::Network do
output = [[0.0], [1.0], [1.0], [0.0]]
train_data = Fann::TrainData.new(input, output)
data = train_data.train_data
ann.train_algorithem(LibFANN::TrainEnum::TrainRprop)
ann.set_hidden_layer_activation_func(LibFANN::ActivationfuncEnum::LeakyRelu)
ann.set_output_layer_activation_func(LibFANN::ActivationfuncEnum::LeakyRelu)
# ann.train_algorithem(LibFANN::TrainEnum::TrainSarprop)
ann.randomzie_weights(0.0, 1.0)
# ann.set_hidden_layer_activation_func(LibFANN::ActivationfuncEnum::LeakyRelu)
# ann.set_output_layer_activation_func(LibFANN::ActivationfuncEnum::LeakyRelu)
if data
ann.train_batch(data, {:max_runs => 8000, :desired_mse => 0.001, :log_each => 1000})
end
Expand All @@ -65,8 +68,9 @@ describe Fann::Network do
train_data = Fann::TrainData.new(input, output)
data = train_data.train_data
ann.train_algorithem(LibFANN::TrainEnum::TrainRprop)
ann.randomzie_weights(0.0, 1.0)
if data
ann.train_batch(data, {:max_neurons => 500, :desired_mse => 0.1, :log_each => 10})
ann.train_batch(data, {:max_neurons => 500, :desired_mse => 0.001, :log_each => 10})
end
result = ann.run([1.0, 1.0])
ann.close
Expand Down
4 changes: 2 additions & 2 deletions src/crystal-fann/cascade_network.cr
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ module Fann
LibFANN.set_activation_function_hidden(@nn, func)
end

def randomzie_weights(min : Float32, max : Float32)
def randomzie_weights(min : Float64, max : Float64)
# randomize_weights = fann_randomize_weights(ann : Fann*, min_weight : Type, max_weight : Type)
LibFANN.randomize_weights(@nn, min, max)
end

def train_batch(train_data : Pointer(LibFANN::TrainData), opts = {:max_neurons => 500, :desired_mse => 0.01_f64, :log_each => 10})
# fun cascadetrain_on_data = fann_cascadetrain_on_data(ann : Fann*, data : TrainData*, max_neurons : LibC::UInt, neurons_between_reports : LibC::UInt, desired_error : LibC::Double)
LibFANN.cascade_train_on_data(@nn, train_data, opts[:max_neurons], opts[:log_each], opts[:desired_mse])
LibFANN.cascade_train_on_data(@nn, train_data, opts[:max_neurons], opts[:log_each], opts[:desired_mse].to_f32)
end

def run(input : Array(Float64))
Expand Down
8 changes: 4 additions & 4 deletions src/crystal-fann/lib_fann.cr
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,8 @@ lib LibFANN
fun get_mse = fann_get_MSE(ann : Fann*) : LibC::Double
fun get_bit_fail = fann_get_bit_fail(ann : Fann*) : LibC::UInt
fun reset_mse = fann_reset_MSE(ann : Fann*)
fun train_on_data = fann_train_on_data(ann : Fann*, data : TrainData*, max_epochs : LibC::UInt, epochs_between_reports : LibC::UInt, desired_error : LibC::Double)
fun train_on_file = fann_train_on_file(ann : Fann*, filename : LibC::Char*, max_epochs : LibC::UInt, epochs_between_reports : LibC::UInt, desired_error : LibC::Double)
fun train_on_data = fann_train_on_data(ann : Fann*, data : TrainData*, max_epochs : LibC::UInt, epochs_between_reports : LibC::UInt, desired_error : LibC::Float)
fun train_on_file = fann_train_on_file(ann : Fann*, filename : LibC::Char*, max_epochs : LibC::UInt, epochs_between_reports : LibC::UInt, desired_error : LibC::Float)
fun train_epoch = fann_train_epoch(ann : Fann*, data : TrainData*) : LibC::Double
fun train_epoch_lw = fann_train_epoch_lw(ann : Fann*, data : TrainData*, label_weight : Type*) : LibC::Double
fun train_epoch_irpropm_gradient = fann_train_epoch_irpropm_gradient(ann : Fann*, data : TrainData*, error_function : (Type*, Type*, LibC::Int, Void* -> Type), x3 : Void*) : LibC::Double
Expand Down Expand Up @@ -381,8 +381,8 @@ lib LibFANN
fun set_sarprop_step_error_shift = fann_set_sarprop_step_error_shift(ann : Fann*, sarprop_step_error_shift : LibC::Double)
fun get_sarprop_temperature = fann_get_sarprop_temperature(ann : Fann*) : LibC::Double
fun set_sarprop_temperature = fann_set_sarprop_temperature(ann : Fann*, sarprop_temperature : LibC::Double)
fun cascade_train_on_data = fann_cascadetrain_on_data(ann : Fann*, data : TrainData*, max_neurons : LibC::UInt, neurons_between_reports : LibC::UInt, desired_error : LibC::Double)
fun cascade_train_on_file = fann_cascadetrain_on_file(ann : Fann*, filename : LibC::Char*, max_neurons : LibC::UInt, neurons_between_reports : LibC::UInt, desired_error : LibC::Double)
fun cascade_train_on_data = fann_cascadetrain_on_data(ann : Fann*, data : TrainData*, max_neurons : LibC::UInt, neurons_between_reports : LibC::UInt, desired_error : LibC::Float)
fun cascade_train_on_file = fann_cascadetrain_on_file(ann : Fann*, filename : LibC::Char*, max_neurons : LibC::UInt, neurons_between_reports : LibC::UInt, desired_error : LibC::Float)
fun get_cascade_output_change_fraction = fann_get_cascade_output_change_fraction(ann : Fann*) : LibC::Double
fun set_cascade_output_change_fraction = fann_set_cascade_output_change_fraction(ann : Fann*, cascade_output_change_fraction : LibC::Double)
fun get_cascade_output_stagnation_epochs = fann_get_cascade_output_stagnation_epochs(ann : Fann*) : LibC::UInt
Expand Down
4 changes: 2 additions & 2 deletions src/crystal-fann/standard_network.cr
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ module Fann
LibFANN.set_activation_function_output(@nn, func)
end

def randomzie_weights(min : Float32, max : Float32)
def randomzie_weights(min : Float64, max : Float64)
# randomize_weights = fann_randomize_weights(ann : Fann*, min_weight : Type, max_weight : Type)
LibFANN.randomize_weights(@nn, min, max)
end
Expand All @@ -50,7 +50,7 @@ module Fann

def train_batch(train_data : Pointer(LibFANN::TrainData), opts = {:max_runs => 200, :desired_mse => 0.01_f64, :log_each => 1})
# train_on_data = fann_train_on_data(ann : Fann*, data : TrainData*, max_epochs : LibC::UInt, epochs_between_reports : LibC::UInt, desired_error : LibC::Double)
LibFANN.train_on_data(@nn, train_data, opts[:max_runs], opts[:log_each], opts[:desired_mse])
LibFANN.train_on_data(@nn, train_data, opts[:max_runs], opts[:log_each], opts[:desired_mse].to_f32)
end

def train_batch_multicore(train_data : Pointer(LibFANN::TrainData), threads : Int32, opts = {:max_runs => 200, :desired_mse => 0.01_f64, :log_each => 1})
Expand Down

0 comments on commit 09fe537

Please sign in to comment.