Skip to content

Commit

Permalink
Update linear
Browse files Browse the repository at this point in the history
  • Loading branch information
hshindo committed Dec 11, 2017
1 parent de29ab9 commit df538fa
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/functions/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ function BLAS.gemm(tA::Char, tB::Char, alpha::Number, A::Var, B::Var)
y = BLAS.gemm(tA, tB, T(alpha), A.data, B.data)
Var(y, BLAS.gemm, (tA,tB,alpha,A,B))
end
BLAS.gemm(tA::Char, tB::Char, alpha::Number, A::Node, B::Node; name="") = Node(BLAS.gemm, (tA,tB,alpha,A,B), name)

function addgrad!(C::Var, ::typeof(BLAS.gemm), tA::Char, tB::Char, alpha::Number, A::Var, B::Var)
T = eltype(C.data)
Expand Down
4 changes: 2 additions & 2 deletions src/functions/linear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ function linear(x::Var, W::Var, b::Var)
end
linear(x::Node, W::Var, b::Var; name="") = Node(linear, (x,W,b), name)

function linear(x::Matrix, W::Matrix, b::Vector)
function linear(x::Matrix, W::Matrix, b)
y = BLAS.gemm('T', 'N', W, x)
broadcast!(+, y, y, b)
b == nothing || broadcast!(+, y, y, b)
y
end

Expand Down

0 comments on commit df538fa

Please sign in to comment.