Commit
Remove var.f
hshindo committed Dec 14, 2017
1 parent 557b42d commit ce29ef1
Showing 16 changed files with 42 additions and 55 deletions.
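In short: this commit removes the separate `f` field from `Var` and instead stores the forward function as the first element of the `args` tuple. A minimal before/after sketch of the calling convention, pieced together from the hunks below (illustrative only, not an additional change in this commit):

```
# Before: the constructor took the function and its argument tuple separately.
relu(x::Var) = Var(relu.(x.data), relu, (x,))     # Var(data, f, args)

# After: the function rides along as args[1], so Var only needs data and args.
relu(x::Var) = Var(relu.(x.data), (relu, x))      # Var(data, (f, inputs...))

# The backward pass then needs a single splat instead of a null check on f:
# addgrad!(v, v.args...) dispatches to addgrad!(y::Var, ::typeof(relu), x::Var).
```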
16 changes: 8 additions & 8 deletions src/functions/activation.jl
@@ -14,7 +14,7 @@ f(x) = (\max(0,x), \max(0,-x))
# References
* Shang et al., ["Understanding and Improving Convolutional Neural Networks via Concatenated Rectified Linear Units"](https://arxiv.org/abs/1603.05201), arXiv 2016.
"""
- crelu(x::Var) = Var(crelu(x.data), crelu, (x,))
+ crelu(x::Var) = Var(crelu(x.data), (crelu,x))
crelu(x::Node; name="") = Node(crelu, (x,), name)

function crelu(x::Array{T}) where T
@@ -55,7 +55,7 @@ x & x > 0 \\
```
where ``\alpha=1``.
"""
- elu(x::Var) = Var(elu.(x.data), elu, (x,))
+ elu(x::Var) = Var(elu.(x.data), (elu,x))
elu(x::Node; name="") = Node(elu, (x,), name)
elu(x::T) where T = x > T(0) ? x : exp(x)-1

@@ -86,7 +86,7 @@ x & x > 0 \\
# References
* Maas et al., ["Rectifier Nonlinearities Improve Neural Network Acoustic Models"](http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf), ICML 2013.
"""
- leaky_relu(x::Var, alpha::Float64=0.1) = Var(leaky_relu.(x.data,eltype(x)(alpha)), leaky_relu, (x,alpha))
+ leaky_relu(x::Var, alpha::Float64=0.1) = Var(leaky_relu.(x.data,eltype(x)(alpha)), (leaky_relu,x,alpha))
leaky_relu(x::Node; name="") = Node(leaky_relu, (x,), name)
leaky_relu(x::T, alpha::T) where T = x >= T(0) ? x : x*alpha

@@ -109,7 +109,7 @@ Rectified Linear Unit.
f(x) = \max(0, x)
```
"""
- relu(x::Var) = Var(relu.(x.data), relu, (x,))
+ relu(x::Var) = Var(relu.(x.data), (relu,x))
relu(x::Node; name="") = Node(relu, (x,), name)
relu(x::T) where T = max(x, T(0))

@@ -140,7 +140,7 @@ where ``\lambda=1.0507`` and ``\alpha=1.6733``.
# References
Klambauer et al., ["Self-Normalizing Neural Networks"](https://arxiv.org/abs/1706.02515), NIPS 2017.
"""
- selu(x::Var) = Var(selu.(x.data), selu, (x,))
+ selu(x::Var) = Var(selu.(x.data), (selu,x))
selu(x::Node; name="") = Node(selu, (x,), name)
selu(x::T) where T = x > 0 ? T(1.0507)*x : T(1.0507)*T(1.6733)*(exp(x)-1)

@@ -165,7 +165,7 @@ Sigmoid logistic function.
f(x) = (1 + \exp(-x))^{-1}
```
"""
- sigmoid(x::Var) = Var(sigmoid.(x.data), sigmoid, (x,))
+ sigmoid(x::Var) = Var(sigmoid.(x.data), (sigmoid,x))
sigmoid(x::Node; name="") = Node(sigmoid, (x,), name)
sigmoid(x::T) where T<:AbstractFloat = 1 / (1 + exp(-x))

@@ -198,7 +198,7 @@ end
Swish(::Type{T}) where T = Swish(zerograd(ones(T,1)))
(f::Swish)(x) = swish(x, f.beta)

- swish(x::Var, beta::Var) = Var(swish.(x.data,beta.data), swish, (x,beta))
+ swish(x::Var, beta::Var) = Var(swish.(x.data,beta.data), (swish,x,beta))
swish(x::Node, beta::Var; name="") = Node(swish, (x,beta), name)
swish(x::T, beta::T) where T = x * sigmoid(beta*x)

@@ -226,7 +226,7 @@ doc"""
Hyperbolic tangent function.
"""
- tanh(x::Var) = Var(tanh.(x.data), tanh, (x,))
+ tanh(x::Var) = Var(tanh.(x.data), (tanh,x))
tanh(x::Node; name="") = Node(tanh, (x,), name)

function addgrad!(y::Var, ::typeof(tanh), x::Var)
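The `tanh` gradient rule above is cut off by the diff view. As a self-contained illustration of the pattern every activation now follows — a forward method that records `(f, inputs...)` and an `addgrad!` method keyed on `typeof(f)` — here is a minimal sketch for `relu`; the gradient body is an assumption for illustration, only the calling convention comes from this commit:

```
relu(x::Var) = Var(relu.(x.data), (relu, x))      # forward: y.args == (relu, x)

# Reached from gradient! via addgrad!(y, y.args...); body is illustrative.
function addgrad!(y::Var, ::typeof(relu), x::Var)
    isvoid(x.grad) && return                      # input does not track a gradient
    for i in 1:length(x.grad)
        x.grad[i] += x.data[i] > 0 ? y.grad[i] : zero(y.grad[i])
    end
end
```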
4 changes: 2 additions & 2 deletions src/functions/blas.jl
@@ -12,7 +12,7 @@ C = \alpha \times \textrm{tA}(A) \times \textrm{tB}(B)
function BLAS.gemm(tA::Char, tB::Char, alpha::Number, A::Var, B::Var)
T = eltype(A)
y = BLAS.gemm(tA, tB, T(alpha), A.data, B.data)
- Var(y, BLAS.gemm, (tA,tB,alpha,A,B))
+ Var(y, (BLAS.gemm,tA,tB,alpha,A,B))
end
BLAS.gemm(tA::Char, tB::Char, alpha::Number, A::Node, B::Node; name="") = Node(BLAS.gemm, (tA,tB,alpha,A,B), name)

@@ -42,7 +42,7 @@ y = \alpha \times \textrm{tA}(A) \times x
function BLAS.gemv(tA::Char, alpha::Number, A::Var, x::Var)
T = eltype(A)
y = BLAS.gemv(tA, T(alpha), A.data, x.data)
- Var(y, BLAS.gemv, (tA,alpha,A,x))
+ Var(y, (BLAS.gemv,tA,alpha,A,x))
end

function addgrad!(y::Var, ::typeof(BLAS.gemv), tA::Char, alpha::Number, A::Var, x::Var)
2 changes: 1 addition & 1 deletion src/functions/concat.jl
@@ -14,7 +14,7 @@ y = concat(2, x1, x2)
"""
function concat(dim::Int, xs::Var...)
y = cat(dim, map(x -> x.data, xs)...)
- Var(y, concat, (dim,xs...))
+ Var(y, (concat,dim,xs...))
end
concat(dim::Int, xs::Node...; name="") = Node(concat, (dim,xs...), name)

2 changes: 1 addition & 1 deletion src/functions/conv.jl
@@ -36,7 +36,7 @@ function (f::Conv1D)(x::Var, batchdims::Vector{Int})
end
h = window1d(f, x.data, batchdims_y)
y = linear(h, f.W.data, f.b.data)
- Var(y, f, (x,f.W,f.b,batchdims,h))
+ Var(y, (f,x,f.W,f.b,batchdims,h))
end
(f::Conv1D)(x::Node, batchdims::Node; name="") = Node(f, (x,batchdims), name)

2 changes: 1 addition & 1 deletion src/functions/dropout.jl
@@ -14,7 +14,7 @@ function dropout(x::Var, rate::Float64, train::Bool)
T = eltype(x)
rx = rand(T, length(x.data))
y = dropout(x.data, T(rate), rx)
- Var(y, dropout, (x,rate,rx))
+ Var(y, (dropout,x,rate,rx))
end
end
dropout(x::Node, rate::Float64, train::Node; name="") = Node(dropout, (x,rate,train), name)
2 changes: 1 addition & 1 deletion src/functions/embeddings.jl
@@ -8,7 +8,7 @@ end
function lookup(embeds::Vector{Var}, x::Var)
y = lookup(embeds, x.data)
xs = map(i -> embeds[i], vec(x.data))
- Var(y, lookup, (xs,))
+ Var(y, (lookup,xs))
end
lookup(embeds::Node, x::Node; name="") = Node(lookup, (embeds,x), name)

2 changes: 1 addition & 1 deletion src/functions/getindex.jl
@@ -11,7 +11,7 @@ y = x[2:2]
Note that `y = x[i]` throws an error since `y` is not a vector but a scalar.
Instead, use `y = x[i:i]`.
"""
- getindex(x::Var, inds::Tuple) = Var(x.data[inds...], getindex, (x,inds))
+ getindex(x::Var, inds::Tuple) = Var(x.data[inds...], (getindex,x,inds))
getindex(x::Var, inds...) = getindex(x, inds)
getindex(x::Node, inds::Tuple; name="") = Node(getindex, (x,inds), name)

2 changes: 1 addition & 1 deletion src/functions/linear.jl
@@ -34,7 +34,7 @@ end

function linear(x::Var, W::Var, b::Var)
y = linear(x.data, W.data, b.data)
- Var(y, linear, (x,W,b))
+ Var(y, (linear,x,W,b))
end
linear(x::Node, W::Node, b::Node; name="") = Node(linear, (x,W,b), name)

8 changes: 4 additions & 4 deletions src/functions/loss.jl
@@ -19,7 +19,7 @@ y = crossentropy(p, q)
```
"""
function crossentropy(p::Var, q::Var)
- Var(crossentropy(p.data,q.data), crossentropy, (p,q))
+ Var(crossentropy(p.data,q.data), (crossentropy,p,q))
end
crossentropy(p::Node, q::Node; name="") = Node(crossentropy, (p,q), name)

@@ -82,7 +82,7 @@ y = l2(x, 0.01)
function l2(x::Var, lambda::Float64)
T = eltype(x)
y = mapreduce(x -> x*x, +, x.data) * T(lambda) / 2
- Var([y], l2, (x,lambda))
+ Var([y], (l2,x,lambda))
end

function addgrad!(y::Var, ::typeof(l2), x::Var, lambda::Float64)
@@ -104,7 +104,7 @@ The mean is calculated over the minibatch.
Note that the error is not scaled by 1/2.
"""
function mse(x1::Var, x2::Var)
- Var(mse(x1.data,x2.data), mse, (x1,x2))
+ Var(mse(x1.data,x2.data), (mse,x1,x2))
end
mse(x1::Node, x2::Node; name="") = Node(mse, (x1,x2), name)

@@ -159,7 +159,7 @@ y = softmax_crossentropy(p, x)
"""
function softmax_crossentropy(p::Var, x::Var)
y, logx = softmax_crossentropy(p.data, x.data)
- Var(y, softmax_crossentropy, (p,x,logx))
+ Var(y, (softmax_crossentropy,p,x,logx))
end
softmax_crossentropy(p::Node, x::Node; name="") = Node(softmax_crossentropy, (p,x), name)

34 changes: 11 additions & 23 deletions src/functions/math.jl
@@ -4,7 +4,7 @@ import Base: +, -, *, /, ^
doc"""
exp(x)
"""
- exp(x::Var) = Var(exp.(x.data), exp, (x,))
+ exp(x::Var) = Var(exp.(x.data), (exp,x))
exp(x::Node; name="") = Node(exp, (x,), name)

function addgrad!(y::Var, ::typeof(exp), x::Var)
@@ -35,7 +35,7 @@ end
doc"""
log(x)
"""
- log(x::Var) = Var(log.(x.data), log, (x,))
+ log(x::Var) = Var(log.(x.data), (log,x))
log(x::Node; name="") = Node(log, (x,), name)

function addgrad!(y::Var, ::typeof(log), x::Var)
@@ -93,12 +93,9 @@ doc"""
+(a::Number, x::Var)
+(x::Var, a::Number)
"""
- +(x1::Var, x2::Var) = Var(x1.data+x2.data, +, (x1,x2))
+ +(x1::Var, x2::Var) = Var(x1.data+x2.data, (+,x1,x2))
+(x1::Union{Number,Array}, x2::Var) = Var(x1) + x2
+(x1::Var, x2::Union{Number,Array}) = x1 + Var(x2)

- +(x1::Node, x2; name="") = Node(+, (x1,x2), name)
- +(x1, x2::Node; name="") = Node(+, (x1,x2), name)
+(x1::Node, x2::Node; name="") = Node(+, (x1,x2), name)

function addgrad!(y::Var, ::typeof(+), x1::Var, x2::Var)
@@ -110,11 +107,10 @@ end
doc"""
-(x1, x2)
"""
- -(x1::Var, x2::Var) = Var(x1.data-x2.data, -, (x1,x2))
+ -(x1::Var, x2::Var) = Var(x1.data-x2.data, (-,x1,x2))
-(a::Number, x::Var) = Var(a) - x
-(x::Var, a::Number) = x - Var(a)
- -(x::Var) = Var(-x.data, -, (x,))
- -(a::Number, x::Node; name="") = Node(-, (Var(a),x), name)
+ -(x::Var) = Var(-x.data, (-,x))
-(x1::Node, x2::Node; name="") = Node(-, (x1,x2), name)
function -(x::Node; name="")
Node(-, (x,), name)
@@ -133,10 +129,8 @@ end
doc"""
.+(x1::Var, x2::Var)
"""
- broadcast(::typeof(+), x1::Var, x2::Var) = Var(broadcast(+,x1.data,x2.data), broadcast, (+,x1,x2))
+ broadcast(::typeof(+), x1::Var, x2::Var) = Var(broadcast(+,x1.data,x2.data), (broadcast,+,x1,x2))
broadcast(::typeof(+), x1::Node, x2::Node; name="") = Node(broadcast, (+,x1,x2), name)
- broadcast(::typeof(+), x1::Node, x2::Var; name="") = Node(broadcast, (+,x1,x2), name)
- broadcast(::typeof(+), x1::Var, x2::Node; name="") = Node(broadcast, (+,x1,x2), name)

function addgrad!(y::Var, ::typeof(broadcast), ::typeof(+), x1::Var, x2::Var)
isvoid(x1.grad) || ∇broadcast_plus!(y.grad, x1.grad)
@@ -154,10 +148,8 @@ end
doc"""
.-(x1::Var, x2::Var)
"""
- broadcast(::typeof(-), x1::Var, x2::Var) = Var(broadcast(-,x1.data,x2.data), broadcast, (-,x1,x2))
+ broadcast(::typeof(-), x1::Var, x2::Var) = Var(broadcast(-,x1.data,x2.data), (broadcast,-,x1,x2))
broadcast(::typeof(-), x1::Node, x2::Node; name="") = Node(broadcast, (-,x1,x2), name)
- broadcast(::typeof(-), x1::Node, x2::Var; name="") = Node(broadcast, (-,x1,x2), name)
- broadcast(::typeof(-), x1::Var, x2::Node; name="") = Node(broadcast, (-,x1,x2), name)

function addgrad!(y::Var, ::typeof(broadcast), ::typeof(-), x1::Var, x2::Var)
isvoid(x1.grad) || ∇broadcast_plus!(y.grad, x1.grad)
@@ -175,10 +167,8 @@ end
doc"""
\.\*(x1::Var, x2::Var)
"""
- broadcast(::typeof(*), x1::Var, x2::Var) = Var(broadcast(*,x1.data,x2.data), broadcast, (*,x1,x2))
+ broadcast(::typeof(*), x1::Var, x2::Var) = Var(broadcast(*,x1.data,x2.data), (broadcast,*,x1,x2))
broadcast(::typeof(*), x1::Node, x2::Node; name="") = Node(broadcast, (*,x1,x2), name)
- broadcast(::typeof(*), x1::Node, x2::Var; name="") = Node(broadcast, (*,x1,x2), name)
- broadcast(::typeof(*), x1::Var, x2::Node; name="") = Node(broadcast, (*,x1,x2), name)

function addgrad!(y::Var, ::typeof(broadcast), ::typeof(*), x1::Var, x2::Var)
isvoid(x1.grad) || ∇broadcast_times!(y.grad, x2.data, x1.grad)
@@ -230,10 +220,8 @@ end
doc"""
\*(A::Var, B::Var)
"""
- *(A::Var, B::Var) = Var(*(A.data,B.data), *, (A,B))
+ *(A::Var, B::Var) = Var(*(A.data,B.data), (*,A,B))
*(A::Node, B::Node; name="") = Node(*, (A,B), name)
- *(A::Var, B::Node; name="") = Node(*, (A,B), name)
- *(A::Node, B::Var; name="") = Node(*, (A,B), name)

function addgrad!(C::Var, ::typeof(*), A::Var, B::Var)
T = eltype(C)
@@ -244,7 +232,7 @@ end
doc"""
/(x1::Var, a)
"""
- /(x::Var, a::Number) = Var(x.data, /, (x,a))
+ /(x::Var, a::Number) = Var(x.data, (/,x,a))
/(x::Node, a::Number; name="") = Node(/, (x,a), name)

function addgrad!(y::Var, ::typeof(/), x::Var, a::Number)
@@ -261,7 +249,7 @@ end
doc"""
^(x::Var, a::Number)
"""
- ^(x::Var, a::Number) = Var(x.data^a, ^, (x,a))
+ ^(x::Var, a::Number) = Var(x.data^a, (^,x,a))
^(x::Node, a::Number; name="") = Node(^, (x,a), name)

function addgrad!(y::Var, ::typeof(^), x::Var, a::Number)
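One detail worth noting in the math.jl hunks: for broadcasting operators the tuple carries two callables, `broadcast` first and the scalar operator second, so the gradient rule dispatches on both. A condensed sketch assembled from the hunks above (the `x2` line is an assumption by symmetry; it is hidden behind the fold in this view):

```
broadcast(::typeof(+), x1::Var, x2::Var) =
    Var(broadcast(+, x1.data, x2.data), (broadcast, +, x1, x2))

# addgrad!(y, y.args...) == addgrad!(y, broadcast, +, x1, x2)
function addgrad!(y::Var, ::typeof(broadcast), ::typeof(+), x1::Var, x2::Var)
    isvoid(x1.grad) || ∇broadcast_plus!(y.grad, x1.grad)
    isvoid(x2.grad) || ∇broadcast_plus!(y.grad, x2.grad)   # assumed; not visible above
end
```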
4 changes: 2 additions & 2 deletions src/functions/reduce.jl
@@ -14,7 +14,7 @@ y = max(x, 1)
"""
function max(x::Var, dim::Int)
y, idx = findmax(x.data, dim)
- Var(y, max, (x,idx))
+ Var(y, (max,x,idx))
end
max(x::Node, dim::Int; name="") = Node(max, (x,dim), name)

@@ -34,7 +34,7 @@ doc"""
function max_batch(x::Var, dims::Vector{Int})
@assert sum(dims) == size(x)[end]
y, idx = max_batch(x.data, dims)
- Var(y, max, (x,idx))
+ Var(y, (max_batch,x,idx))
end
max_batch(x::Node, dims::Node; name="") = Node(max_batch, (x,dims), name)

2 changes: 1 addition & 1 deletion src/functions/reshape.jl
@@ -10,7 +10,7 @@ x = Var(rand(T,10,5))
y = reshape(x, (2,5), [2,3])
```
"""
- reshape(x::Var, dims::Tuple) = Var(reshape(x.data,dims), reshape, (x,))
+ reshape(x::Var, dims::Tuple) = Var(reshape(x.data,dims), (reshape,x))
reshape(x::Var, dims::Int...) = reshape(x, dims)
reshape(x::Node, dims::Tuple; name="") = Node(reshape, (x,dims), name)

4 changes: 2 additions & 2 deletions src/functions/softmax.jl
@@ -9,7 +9,7 @@ Softmax function over the given dimension.
f(x) = \exp(x) \over \sum \exp(x)
```
"""
- softmax(x::Var) = Var(softmax(x.data), softmax, (x,))
+ softmax(x::Var) = Var(softmax(x.data), (softmax,x))
softmax(x::Node, name="") = Node(softmax, (x,), name)

function softmax(x::Vector{T}) where T
Expand Down Expand Up @@ -83,7 +83,7 @@ end
Logarithm of softmax function.
"""
- logsoftmax(x::Var) = Var(logsoftmax(x.data), logsoftmax, (x,))
+ logsoftmax(x::Var) = Var(logsoftmax(x.data), (logsoftmax,x))
logsoftmax(x::Node; name="") = Node(logsoftmax, (x,), name)

function logsoftmax(x::Matrix{T}) where T
2 changes: 1 addition & 1 deletion src/functions/split.jl
@@ -13,7 +13,7 @@ function split(x::Var, dims::Vector{Int})
push!(ys, y)
cumdim += d
end
- Var(ys, split, (x,dims))
+ Var(ys, (split,x,dims))
end

#=
4 changes: 2 additions & 2 deletions src/functions/standardize.jl
@@ -30,10 +30,10 @@ function standardize(x::Var, train::Bool, scale::Var, bias::Var, runmean, runvar
invstd = T(1) ./ sqrt.(xvar + T(eps))
xhat = (x.data .- xmean) .* invstd
data = xhat .* scale.data .+ bias.data
- Var(data, standardize, (x,scale,bias,invstd,xhat))
+ Var(data, (standardize,x,scale,bias,invstd,xhat))
else
data = (x.data .- runmean) ./ sqrt.(runvar + T(eps)) .* scale.data .+ bias.data
- Var(data, standardize, (x,scale,bias))
+ Var(data, (standardize,x,scale,bias))
end
end
standardize(x::Node, train, scale, bias, runmean, runvar; name="") = Node(standardize, (x,train,scale,bias,runmean,runvar), name)
7 changes: 3 additions & 4 deletions src/var.jl
@@ -15,13 +15,12 @@ Variable struct.
"""
mutable struct Var
data
- f
args
grad
end

- function Var(data, f=nothing, args=(); grad=nothing)
- Var(data, f, args, grad)
+ function Var(data, args=(); grad=nothing)
+ Var(data, args, grad)
end
zerograd(data) = Var(data, grad=zeros(data))

@@ -81,7 +80,7 @@ function gradient!(top::Var)
for i = length(sorted):-1:1
v = sorted[i]
isvoid(v.grad) && continue
- isvoid(v.f) || addgrad!(v, v.f, v.args...)
+ isempty(v.args) || addgrad!(v, v.args...)
end
filter(isparam, sorted)
end
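The var.jl hunks are the heart of the change: with `f` folded into `args`, the backward driver no longer needs a separate null check on `v.f`. A short usage sketch of how the pieces fit after this commit — `zerograd`, `exp`, and `gradient!` appear in the hunks above; the shapes and the surrounding framing are illustrative assumptions:

```
x = zerograd(rand(Float32, 10))   # Var carrying data plus a gradient buffer
y = exp(x)                        # y.args == (exp, x); there is no y.f anymore

# gradient! walks the graph in reverse order and, for every node with
# non-empty args, splats them into addgrad!:
#   addgrad!(v, v.args...)  ->  addgrad!(y, ::typeof(exp), x)
params = gradient!(y)             # returns the parameter Vars (filter(isparam, sorted))
```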

0 comments on commit ce29ef1
