diff --git a/src/gradcheck.lua b/src/gradcheck.lua
index 4e31780..7dc8cf5 100644
--- a/src/gradcheck.lua
+++ b/src/gradcheck.lua
@@ -1,5 +1,6 @@
 -- Autograd
 local autograd = require 'autograd'
+local util = require 'autograd.util'
 
 -- Perturbation (finite diffs):
 local perturbation = 1e-6
@@ -12,20 +13,30 @@ local function jacobianFromAutograd(func, inputs, key)
    -- Autograd:
    local df = autograd(func)
    local grads = df(table.unpack(inputs))
-   local gradsVerify = df(table.unpack(inputs))
 
    -- Find grad:
    local g = autograd.util.nestedGet(grads, key)
+   local g_clone
+   if torch.isTensor(g) then
+      g_clone = g:clone()
+   end
+
+   -- Get the grad again
+   local gradsVerify = df(table.unpack(inputs))
    local gVerify = autograd.util.nestedGet(gradsVerify, key)
    local err
+   local overwrite_err = 0
    if torch.isTensor(g) then
       err = (g - gVerify):abs():max()
+      overwrite_err = (g - g_clone):abs():max()
    else
       err = torch.abs(g - gVerify)
    end
 
    if err ~= 0 then
       error("autograd gradient not deterministic")
+   elseif overwrite_err ~= 0 then
+      error("autograd gradient overwritten when called twice")
    end
 
    -- Return grads:
diff --git a/test/test.lua b/test/test.lua
index 5adc28f..8c8e8eb 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1688,8 +1688,9 @@ local tests = {
       end
       tester:assert(gradcheck(f4,{x=torch.randn(10,10),y=torch.randn(3)}), "Incorrect gradient")
       local f5 = function(params)
-         params.x[2] = params.y*2.0
-         return torch.sum(params.x)
+         local xc = torch.clone(params.x)
+         xc[2] = params.y * 2.0
+         return torch.sum(xc)
       end
       tester:assert(gradcheck(f5,{x=torch.randn(10,10),y=torch.randn(10)}), "Incorrect gradient")
    end,