Incorrect derivative from bug in jitrules.jl: nonzero dval added twice
#2182
Comments
Simpler MWE:

using Enzyme
concat() = ()
concat(a) = a
concat(a, b) = (a..., b...)
concat(a, b, c...) = concat(concat(a, b), c...)
f(x) = first(concat(x...))
x = [(2.0, [2.7])]
dx = [(1.0, [1.0])]
res = Enzyme.autodiff(Reverse, f, Active, Duplicated(x, dx))
dx_true = [(2.0, [1.0])]
@show dx
@show dx_true

The problem appears in the runtime handling of the type-unstable call to

I wonder if the error is that the following code handles active parts of

Enzyme.jl/src/rules/jitrules.jl
Lines 1454 to 1479 in 3b36ea2
I get that the idea is to accumulate shadows for active args, including active parts of MixedDuplicated, but at least in the MWE it seems like MixedDuplicated shadows have already been handled in-place by the time we get here, so all that remains is to assign them to

@inbounds endexprs[i, w] = if args[i] <: Active || args[i] <: MixedDuplicated || args[i] <: BatchMixedDuplicated
    expr = if args[i] <: Active || f <: typeof(Base.tuple)
        if width == 1
            :(tup[$i])
        else
            :(tup[$i][$w])
        end
    elseif args[i] <: MixedDuplicated
        :(args[$i].dval[])
    else
        :(args[$i].dval[$w][])
    end
    quote
        idx_of_vec, idx_in_vec = $(lengths[i])
        vec = @inbounds shadowargs[idx_of_vec][$w]
        if vec isa Base.RefValue
            vecld = vec[]
            T = Core.Typeof(vecld)
            @assert !(vecld isa Base.RefValue)
            vec[] = recursive_index_add(T, vecld, Val(idx_in_vec), $expr)
        elseif $(args[i] <: Active || f <: typeof(Base.tuple))
            val = @inbounds vec[idx_in_vec]
            add_into_vec!(Base.inferencebarrier(val), $expr, vec, idx_in_vec)
        else
            # MixedDuplicated: the shadow was already accumulated in place, so assign instead of adding
            @inbounds vec[idx_in_vec] = $expr
        end
    end

(Don't know if something would also need to be done wrt. the
Yeah, you're absolutely right, I don't think we should update anything for the MixedDuplicated here. Make a PR (and definitely add the test)?
Certain patterns of splatting/tuple concatenation along the lines of test/mixedapplyiter.jl lead to incorrect derivatives when a Duplicated's dval is nonzero. Instead of dval .+= dret, you effectively get dval .+= dval .+ dret. The relevant tests always use dval .= 0, concealing the bug.

MWE:
The double accumulation happens in the following function:
Enzyme.jl/src/rules/jitrules.jl
Lines 1395 to 1404 in 3b36ea2
When the MWE above reaches this call, vec == dx, val == dx[1], and expr == dx_true[1]. That is, expr already contains the correct final value of dx[1], including the contribution from the initial values in dx, so applying recursive_add to add them up again is incorrect. (The array-wrapped value is not affected by the bug because guaranteed_nonactive is passed to recursive_add, stopping the recursion at DupState values.)

I've verified that the MWE is fixed if I change line 1397 to
The entire test suite also passes after this change. However, while I haven't fully wrapped my head around jitrules.jl, it doesn't look like this would be the proper solution.
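To make the reported arithmetic concrete, here is a minimal plain-Julia sketch of the two accumulation patterns described above. It does not call Enzyme. The names dval0, dret, correct, and buggy are hypothetical and chosen only for illustration, and the numbers mirror the first tuple element of the MWE, assuming the gradient of f with respect to that element is 1.0.

```julia
# Plain-Julia illustration only; nothing here is Enzyme API.
dval0 = 1.0   # initial shadow value, as in dx = [(1.0, [1.0])]
dret  = 1.0   # assumed gradient of f w.r.t. the first tuple element

# Intended behaviour: dval += dret
correct = dval0 + dret

# Reported behaviour: dval += dval + dret (the initial shadow is counted twice)
buggy = dval0 + (dval0 + dret)

# With a zero-initialized shadow both patterns agree, which is why
# tests that always reset dval to zero cannot detect the problem.
zero_init_correct = 0.0 + dret
zero_init_buggy   = 0.0 + (0.0 + dret)

@show correct buggy
@show zero_init_correct == zero_init_buggy
```

Under these assumptions the correct result matches dx_true[1][1] from the MWE, while the double-counted result is larger by exactly the initial shadow value. A regression test therefore needs a nonzero initial dval.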