Skip to content

Commit

Permalink
Update functions
Browse files Browse the repository at this point in the history
  • Loading branch information
hshindo committed Dec 14, 2017
1 parent df538fa commit 557b42d
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 32 deletions.
2 changes: 1 addition & 1 deletion src/Merlin.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ end
include("functions/activation.jl")
include("functions/argmax.jl")
include("functions/blas.jl")
include("functions/cat.jl")
include("functions/concat.jl")
include("functions/conv.jl")
include("functions/dropout.jl")
include("functions/embeddings.jl")
Expand Down
14 changes: 7 additions & 7 deletions src/functions/cat.jl → src/functions/concat.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
import Base.cat
export concat

"""
cat(dim::Int, xs::Var...)
concat(dim::Int, xs::Var...)
Concatenate arrays over the given dimension.
# Example
```julia
x1 = Var(rand(Float32,4,3))
x2 = Var(rand(Float32,4,5))
y = cat(2, x1, x2)
y = concat(2, x1, x2)
```
"""
function cat(dim::Int, xs::Var...)
function concat(dim::Int, xs::Var...)
y = cat(dim, map(x -> x.data, xs)...)
Var(y, cat, (dim,xs...))
Var(y, concat, (dim,xs...))
end
cat(dim::Int, xs::Node...; name="") = Node(cat, (dim,xs...), name)
concat(dim::Int, xs::Node...; name="") = Node(concat, (dim,xs...), name)

function addgrad!(y::Var, ::typeof(cat), dim::Int, xs::Var...)
function addgrad!(y::Var, ::typeof(concat), dim::Int, xs::Var...)
T, N = eltype(y), ndims(y)
offset = 1
for x in xs
Expand Down
15 changes: 4 additions & 11 deletions src/functions/embeddings.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
export embeddings, lookup

function embeddings(::Type{T}, insize::Int, outsize::Int; init_w=Normal(0,0.01)) where T
w = init_w(T, outsize, insize)
[zerograd(w[:,i]) for i=1:size(w,2)]
function embeddings(::Type{T}, insize::Int, outsize::Int; init_W=Normal(0,0.01)) where T
W = init_W(T, outsize, insize)
[zerograd(W[:,i]) for i=1:size(W,2)]
end

function lookup(embeds::Vector{Var}, x::Var)
y = lookup(embeds, x.data)
xs = map(i -> embeds[i], vec(x.data))
Var(y, lookup, (xs,))
end

function lookup(w::Var, x::Var)
@assert w.grad == nothing
y = lookup(w.data, x.data)
Var(y, lookup, ())
end
lookup(embeds::Vector{Var}, x::Node; name="") = Node(lookup, (embeds,x), name)
lookup(w::Var, x::Node; name="") = Node(lookup, (w,x), name)
lookup(embeds::Node, x::Node; name="") = Node(lookup, (embeds,x), name)

function lookup(embeds::Vector{Var}, x::Array{Int})
e1 = embeds[1].data
Expand Down
5 changes: 3 additions & 2 deletions src/functions/linear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@ function Linear(::Type{T}, insize::Int, outsize::Int; init_W=Xavier(), init_b=Fi
b = init_b(T, outsize)
Linear(zerograd(W), zerograd(b))
end
(f::Linear)(x) = linear(x, f.W, f.b)
(f::Linear)(x::Var) = linear(x, f.W, f.b)
(f::Linear)(x::Node) = linear(x, Node(f.W), Node(f.b))

function linear(x::Var, W::Var, b::Var)
y = linear(x.data, W.data, b.data)
Var(y, linear, (x,W,b))
end
linear(x::Node, W::Var, b::Var; name="") = Node(linear, (x,W,b), name)
linear(x::Node, W::Node, b::Node; name="") = Node(linear, (x,W,b), name)

function linear(x::Matrix, W::Matrix, b)
y = BLAS.gemm('T', 'N', W, x)
Expand Down
13 changes: 6 additions & 7 deletions src/functions/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ function recurrent(f, x::Var, batchdims::Vector{Int}, h0::Var; rev=false)
i += rev ? batchdims[p]-t : t-1
push!(xts, x[:,i:i])
end
xt = cat(2, xts...)
xt = concat(2, xts...)
if size(h,2) < size(xt,2)
@assert size(h,2) == 1
h = cat(2, ntuple(_ -> h, size(xt,2))...)
h = concat(2, ntuple(_ -> h, size(xt,2))...)
elseif size(h,2) > size(xt,2)
h = h[:,1:size(xt,2)]
end
xt = cat(1, xt, h)
xt = concat(1, xt, h)
h = f(xt)
for j = 1:length(perm)
p = perm[j]
Expand All @@ -36,7 +36,7 @@ function recurrent(f, x::Var, batchdims::Vector{Int}, h0::Var; rev=false)
hs[i] = h[:,j:j]
end
end
cat(2, hs...)
concat(2, hs...)
end

doc"""
Expand Down Expand Up @@ -85,7 +85,6 @@ function LSTM(::Type{T}, insize::Int, outsize::Int; init_W=Xavier(), init_U=Orth
U = init_U(T, insize, 4outsize)
WU = cat(1, W, U)
b = zeros(T, 4outsize)
b[1:outsize] = ones(T, outsize) # forget gate initializes to 1
h0 = zeros(T, outsize, 1)
c0 = zeros(T, outsize, 1)
LSTM(zerograd(WU), zerograd(b), zerograd(h0), zerograd(c0))
Expand All @@ -102,7 +101,7 @@ function (lstm::LSTM)(x::Var, batchdims; rev=false)
o = sigmoid(a[2n+1:3n,:])
if size(c,2) < size(xt,2)
@assert size(c,2) == 1
c = cat(2, ntuple(_ -> c, size(xt,2))...)
c = concat(2, ntuple(_ -> c, size(xt,2))...)
elseif size(c,2) > size(xt,2)
c = c[:,1:size(xt,2)]
end
Expand Down Expand Up @@ -134,5 +133,5 @@ end
function (bilstm::BiLSTM)(x::Var, batchdims)
h1 = bilstm.fwd(x, batchdims)
h2 = bilstm.bwd(x, batchdims, rev=true)
cat(1, h1, h2)
concat(1, h1, h2)
end
4 changes: 2 additions & 2 deletions src/graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ mutable struct Node
id::Int
end

Node(name::String) = Node(nothing, (), name)
Node(f=nothing; name="") = Node(f, (), name)
Node(f, args, name) = Node(f, args, name, 0)

struct Graph
Expand Down Expand Up @@ -43,7 +43,7 @@ function (g::Graph)(xs::Pair...)
for i = 1:length(g.nodes)
node = g.nodes[i]
if isempty(node.args)
isassigned(temps,i) || (temps[i] = node)
isassigned(temps,i) || (temps[i] = node.f)
else
args = map(node.args) do arg
isa(arg,Node) ? temps[arg.id] : arg
Expand Down
4 changes: 2 additions & 2 deletions test/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ end
@testgrad BLAS.gemv('T',1,A,x) A x
end

@testset "cat" for i = 1:5
@testset "concat" for i = 1:5
x1 = zerograd(randn(T,10,5,2))
x2 = zerograd(randn(T,10,5,2))
for dim = 1:3
@testgrad cat(dim,x1,x2) x1 x2
@testgrad concat(dim,x1,x2) x1 x2
end
end

Expand Down

0 comments on commit 557b42d

Please sign in to comment.