Incorrect derivative from bug in jitrules.jl: nonzero dval added twice #2182

Open

danielwe opened this issue Dec 6, 2024 · 2 comments
danielwe commented Dec 6, 2024

Certain patterns of splatting/tuple concatenation along the lines of `test/mixedapplyiter.jl` lead to incorrect derivatives when a `Duplicated`'s `dval` is nonzero. Instead of `dval .+= dret`, you effectively get `dval .+= dval .+ dret`. The relevant tests always use `dval .= 0`, concealing the bug.

MWE:

```julia
using Enzyme

concat() = ()
concat(a) = a
concat(a, b) = (a..., b...)
concat(a, b, c...) = concat(concat(a, b), c...)

function f(x)
    t = concat(x...)  # fails
    # t = x[1]        # works
    return t[1] * t[1] + t[2][1] * t[2][1]
end

x = [(2.0, [2.7])]
dx = [(1.0, [1.0])]
res = Enzyme.autodiff(Reverse, f, Active, Duplicated(x, dx))
dx_true = [(5.0, [6.4])]
@show dx
@show dx_true
```

```
julia> include("bug.jl");
dx = [(6.0, [6.4])]
dx_true = [(5.0, [6.4])]
```
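For reference, the expected numbers can be checked by hand. The pullback of `t[1] * t[1]` at `t[1] = 2.0` is `4.0`, which should be accumulated into the initial shadow `1.0`. A plain-Julia sketch of the arithmetic (no Enzyme involved), contrasting the correct accumulation with the buggy pattern described above:

```julia
# Hand-computed pullback of the t[1] * t[1] term at t[1] = 2.0:
x1   = 2.0
dret = 2 * x1  # derivative of t[1]^2 w.r.t. t[1], i.e. 4.0
dval = 1.0     # nonzero initial shadow dx[1][1]

correct = dval + dret         # dval .+= dret          -> 5.0, matching dx_true
buggy   = dval + dval + dret  # dval .+= dval .+ dret  -> 6.0, matching the observed dx
```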

The double accumulation happens in the following function:

```julia
function add_into_vec!(val::T, expr, vec, idx_in_vec) where {T}
    if ismutable(vec)
        @inbounds vec[idx_in_vec] = recursive_add(val, expr, identity, guaranteed_nonactive)
    else
        error(
            "Enzyme Mutability Error: Cannot in place to immutable value vec[$idx_in_vec] = $val, vec=$vec",
        )
    end
    nothing
end
```

When the MWE above reaches this call, `vec == dx`, `val == dx[1]`, and `expr == dx_true[1]`. That is, `expr` already contains the correct final value of `dx[1]`, including the contribution from the initial values in `dx`, so applying `recursive_add` to add them up again is incorrect. (The array-wrapped value is not affected by the bug because `guaranteed_nonactive` is passed to `recursive_add`, stopping the recursion at `DupState` values.)
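To make the double count concrete, here is a toy model of what happens inside `add_into_vec!`. The helper `toy_recursive_add` is a hypothetical stand-in for `recursive_add`, which for a plain scalar reduces to addition:

```julia
# Hypothetical scalar reduction of recursive_add: plain addition.
toy_recursive_add(val, expr) = val + expr

vec  = [1.0]      # stands in for dx; vec[1] is the nonzero initial shadow
val  = vec[1]     # val == dx[1] == 1.0, read back out of vec
expr = val + 4.0  # expr already holds the final value: initial shadow + pullback

# The initial shadow is added a second time, giving 6.0 instead of 5.0:
vec[1] = toy_recursive_add(val, expr)
```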

I've verified that the MWE is fixed if I change line 1397 to

```julia
@inbounds vec[idx_in_vec] = expr
```

The entire test suite also passes after this change. However, I haven't fully wrapped my head around `jitrules.jl`, and it doesn't look like this would be the proper solution.

danielwe commented Dec 7, 2024

Simpler MWE:

```julia
using Enzyme

concat() = ()
concat(a) = a
concat(a, b) = (a..., b...)
concat(a, b, c...) = concat(concat(a, b), c...)

f(x) = first(concat(x...))

x = [(2.0, [2.7])]
dx = [(1.0, [1.0])]
res = Enzyme.autodiff(Reverse, f, Active, Duplicated(x, dx))
dx_true = [(2.0, [1.0])]
@show dx
@show dx_true
```
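The arithmetic here is even simpler: the pullback of `first` w.r.t. `t[1]` is `1.0`, so with the initial shadow `1.0` the correct result is `dx[1][1] == 2.0`, while the double-accumulation pattern from the first comment would yield `3.0`:

```julia
dval = 1.0  # initial shadow dx[1][1]
dret = 1.0  # pullback of first(...) w.r.t. t[1]

correct = dval + dret         # 2.0, matching dx_true
buggy   = dval + dval + dret  # 3.0, per the dval .+= dval .+ dret pattern
```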

The problem appears in the runtime handling of the type-unstable call to `concat` in the reverse pass. I can confirm that just setting `vec[idx_in_vec] = expr` is not the solution; it gives incorrect answers in cases where the arguments to `concat` are `Active` instead of `MixedDuplicated` (apparently no tests hit this functionality at the moment).

I wonder if the error is that the following code handles the active parts of `MixedDuplicated` args the same way as `Active` args?

```julia
@inbounds endexprs[i, w] = if args[i] <: Active || args[i] <: MixedDuplicated || args[i] <: BatchMixedDuplicated
    expr = if args[i] <: Active || f <: typeof(Base.tuple)
        if width == 1
            :(tup[$i])
        else
            :(tup[$i][$w])
        end
    elseif args[i] <: MixedDuplicated
        :(args[$i].dval[])
    else
        :(args[$i].dval[$w][])
    end
    quote
        idx_of_vec, idx_in_vec = $(lengths[i])
        vec = @inbounds shadowargs[idx_of_vec][$w]
        if vec isa Base.RefValue
            vecld = vec[]
            T = Core.Typeof(vecld)
            @assert !(vecld isa Base.RefValue)
            vec[] = recursive_index_add(T, vecld, Val(idx_in_vec), $expr)
        else
            val = @inbounds vec[idx_in_vec]
            add_into_vec!(Base.inferencebarrier(val), $expr, vec, idx_in_vec)
        end
    end
end
```

I get that the idea is to accumulate shadows for active args, including the active parts of `MixedDuplicated`, but at least in the MWE it seems like the `MixedDuplicated` shadows have already been handled in place by the time we get here; all that remains is to assign them to `vec[idx_in_vec]`. Could the solution look something like this?

```julia
@inbounds endexprs[i, w] = if args[i] <: Active || args[i] <: MixedDuplicated || args[i] <: BatchMixedDuplicated
    expr = if args[i] <: Active || f <: typeof(Base.tuple)
        if width == 1
            :(tup[$i])
        else
            :(tup[$i][$w])
        end
    elseif args[i] <: MixedDuplicated
        :(args[$i].dval[])
    else
        :(args[$i].dval[$w][])
    end

    quote
        idx_of_vec, idx_in_vec = $(lengths[i])
        vec = @inbounds shadowargs[idx_of_vec][$w]
        if vec isa Base.RefValue
            vecld = vec[]
            T = Core.Typeof(vecld)
            @assert !(vecld isa Base.RefValue)
            vec[] = recursive_index_add(T, vecld, Val(idx_in_vec), $expr)
        elseif $(args[i] <: Active || f <: typeof(Base.tuple))
            val = @inbounds vec[idx_in_vec]
            add_into_vec!(Base.inferencebarrier(val), $expr, vec, idx_in_vec)
        else
            @inbounds vec[idx_in_vec] = $expr
        end
    end
end
```

(I don't know whether something would also need to be done w.r.t. the `Base.RefValue` branch?)

wsmoses commented Dec 7, 2024

Yeah, you're absolutely right; I don't think we should update anything for the `MixedDuplicated` here. Make a PR (and definitely add the test)?
