Skip to content

Commit

Permalink
Support jl_genericmemory_copyto
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 4, 2025
1 parent 4b97d8e commit e5ecccd
Showing 1 changed file with 74 additions and 0 deletions.
74 changes: 74 additions & 0 deletions src/rules/llvmrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1779,6 +1779,74 @@ end
return nothing
end

@register_fwd function genericmemory_copyto_fwd(B, orig, gutils, normalR, shadowR)
if is_constant_inst(gutils, orig)
return true
end
origops = collect(operands(orig))
width = get_width(gutils)
origops = collect(operands(orig))
width = get_width(gutils)

legal, dest_ty, _ = abs_typeof(origops[1])

if !legal
emit_error(B, orig, "Enzyme: could not deduce element type of value within generic_memory_copyto of " * string(origops[1]) * " within "*string(orig))
else
dest_ty = Vector{Any}
end

ET = eltype(dest_ty)

reg = active_reg_inner(Ty, (), world)
if reg == ActiveState || reg == MixedState
emit_error(B, orig, "Enzyme: element type $ET of generic_memory_copyto is potentially active ($reg) and not presently supported")
end

args = LLVM.Value[]
for a in origops[1:end-2]
v = invert_pointer(gutils, a, B)
push!(args, v)
end
push!(args, new_from_original(gutils, origops[end-1]))
valTys = API.CValueType[
API.VT_Shadow,
API.VT_Shadow,
API.VT_Shadow,
API.VT_Shadow,
API.VT_Primal,
]

if width == 1
vargs = args
cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, vargs, valTys, false) #=lookup=#
debug_from_orig!(gutils, cal, orig)
callconv!(cal, callconv(orig))
else
shadowres =
UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))))
for idx = 1:width
vargs = LLVM.Value[]
for a in args[1:end-1]
push!(vargs, extract_value!(B, a, idx - 1))
end
push!(vargs, args[end])
cal =
call_samefunc_with_inverted_bundles!(B, gutils, orig, vargs, valTys, false) #=lookup=#
debug_from_orig!(gutils, cal, orig)
callconv!(cal, callconv(orig))
end
end

return false
end
@register_aug function genericmemory_copyto_augfwd(B, orig, gutils, normalR, shadowR, tapeR)
genericmemory_copyto_fwd(B, orig, gutils, normalR, shadowR)
end
@register_rev function genericmemory_copyto_rev(B, orig, gutils, tape)
return nothing
end

@register_fwd function jl_array_sizehint_fwd(B, orig, gutils, normalR, shadowR)
origops = collect(operands(orig))
if is_constant_value(gutils, origops[1])
Expand Down Expand Up @@ -2193,6 +2261,12 @@ end
@revfunc(genericmemory_copy_slice_rev),
@fwdfunc(genericmemory_copy_slice_fwd),
)
register_handler!(
("jl_genericmemory_copyto", "ijl_genericmemory_copyto"),
@augfunc(genericmemory_copyto_augfwd),
@revfunc(genericmemory_copyto_rev),
@fwdfunc(genericmemory_copyto_fwd),
)
register_handler!(
("jl_genericmemory_slice", "ijl_genericmemory_slice"),
@augfunc(genericmemory_slice_augfwd),
Expand Down

0 comments on commit e5ecccd

Please sign in to comment.