Skip to content

Commit

Permalink
Merge pull request #8 from hugoabonizio/feat/add-save-and-load
Browse files Browse the repository at this point in the history
Add save and load methods
  • Loading branch information
bararchy authored Sep 28, 2017
2 parents 09fe537 + c4a74df commit b6b337e
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 0 deletions.
42 changes: 42 additions & 0 deletions spec/network_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,46 @@ describe Fann::Network do
ann.close
(result < [0.1]).should be_true
end

context "saving current network" do
it "saves standard networks" do
tempfile = Tempfile.new("foo")
File.size(tempfile.path).should eq 0
ann = Fann::Network::Standard.new(2, [2], 1)
ann.save(tempfile.path)
(File.size(tempfile.path) > 0).should be_true
end

it "saves cascade networks" do
tempfile = Tempfile.new("bar")
File.size(tempfile.path).should eq 0
ann = Fann::Network::Cascade.new(2, 1)
ann.save(tempfile.path)
(File.size(tempfile.path) > 0).should be_true
end
end

context "loading a configuration file" do
it "loads standard networks" do
input = 2
output = 1
tempfile = Tempfile.new("standard")
original = Fann::Network::Standard.new(input, [2], output)
original.save(tempfile.path)
loaded = Fann::Network::Standard.new(tempfile.path)
loaded.input_size.should eq input
loaded.output_size.should eq output
end

it "loads cascade networks" do
input = 2
output = 1
tempfile = Tempfile.new("cascade")
original = Fann::Network::Cascade.new(input, output)
original.save(tempfile.path)
loaded = Fann::Network::Cascade.new(tempfile.path)
loaded.input_size.should eq input
loaded.output_size.should eq output
end
end
end
1 change: 1 addition & 0 deletions spec/spec_helper.cr
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
require "spec"
require "tempfile"
require "../src/crystal-fann"
12 changes: 12 additions & 0 deletions src/crystal-fann/cascade_network.cr
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,21 @@ module Fann
module Network
class Cascade
property :nn
getter :input_size
getter :output_size

def initialize(input : Int32, output : Int32)
@output_size = output
@input_size = input
@nn = LibFANN.create_shortcut(2, input, output)
end

def initialize(path : String)
@nn = LibFANN.create_from_file(path)
@input_size = LibFANN.get_num_input(@nn)
@output_size = LibFANN.get_num_output(@nn)
end

def mse
LibFANN.get_mse(@nn)
end
Expand Down Expand Up @@ -43,6 +51,10 @@ module Fann
result = LibFANN.run(@nn, input.to_unsafe)
Slice.new(result, @output_size).to_a
end

def save(path : String) : Int32
LibFANN.save(@nn, path)
end
end
end
end
13 changes: 13 additions & 0 deletions src/crystal-fann/standard_network.cr
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ module Fann
module Network
class Standard
property :nn
getter :input_size
getter :output_size

def initialize(input : Int32, hidden : Array(Int32), output : Int32)
@logger = Logger.new(STDOUT)
Expand All @@ -17,6 +19,13 @@ module Fann
@nn = LibFANN.create_standard_array(layers.size, layers.to_unsafe)
end

def initialize(path : String)
@logger = Logger.new(STDOUT)
@nn = LibFANN.create_from_file(path)
@input_size = LibFANN.get_num_input(@nn)
@output_size = LibFANN.get_num_output(@nn)
end

def mse
LibFANN.get_mse(@nn)
end
Expand Down Expand Up @@ -79,6 +88,10 @@ module Fann
result = LibFANN.run(@nn, input.to_unsafe)
Slice.new(result, @output_size).to_a
end

def save(path : String) : Int32
LibFANN.save(@nn, path)
end
end
end
end

0 comments on commit b6b337e

Please sign in to comment.