-
Notifications
You must be signed in to change notification settings - Fork 81
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
KA.jl-related slowdowns #565
Comments
Hmm, this is a little disconcerting: Even a very simple Cartesian kernel exhibits a very significant slowdown. using CUDA, KernelAbstractions, Chairmarks
"""
    main()

Benchmark the two broadcast-copy implementations against each other.

Allocates an uninitialized 512×1000 `Float32` `CuArray`, builds a broadcast of
`identity` over the scalar `0f0` that is instantiated with the destination's
axes, and then displays the `@b` timing of `copyto_old!` and `copyto_new!`,
each wrapped in `CUDA.@sync`.
"""
function main()
    dest = CuArray{Float32}(undef, 512, 1000)

    # Build the lazy scalar broadcast first, then re-wrap it with the axes of
    # the destination so that it can be indexed like `dest`.
    scalar_bc = Broadcast.broadcasted(identity, 0f0)
    shaped_bc = Broadcast.instantiate(
        Broadcast.Broadcasted(scalar_bc.f, scalar_bc.args, axes(dest)),
    )

    print("Old: ")
    display(@b CUDA.@sync copyto_old!(dest, shaped_bc))

    print("New: ")
    display(@b CUDA.@sync copyto_new!(dest, shaped_bc))
end
"""
    copyto_old!(dest::AbstractArray, bc)

Materialize the broadcast `bc` into `dest` with a hand-written CUDA kernel and
return `dest`.

Throws via `Broadcast.throwdm` when the axes of `dest` and `bc` differ, and
returns immediately when `dest` is empty. Each thread computes a linear index
from its block and thread position, converts it to a Cartesian index of `dest`,
and writes one element.
"""
@inline function copyto_old!(dest::AbstractArray, bc)
    if !(axes(dest) == axes(bc))
        Broadcast.throwdm(axes(dest), axes(bc))
    end
    if isempty(dest)
        return dest
    end
    bc = Broadcast.preprocess(dest, bc)

    function broadcast_kernel(dest, bc)
        linear = threadIdx().x + (blockIdx().x - 1) * blockDim().x
        # Threads past the end of the array have nothing to do.
        linear > length(dest) && return
        @inbounds begin
            cart = CartesianIndices(dest)[linear]
            dest[cart] = bc[cart]
        end
        return
    end

    # Compile without launching so the launch configuration can be queried.
    kernel = @cuda launch=false broadcast_kernel(dest, bc)
    config = launch_configuration(kernel.fun)
    nelems = length(dest)
    nthreads = min(nelems, config.threads)
    nblocks = cld(nelems, nthreads)
    kernel(dest, bc; threads=nthreads, blocks=nblocks)
    return dest
end
@inline function copyto_new!(dest::AbstractArray, bc)
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
isempty(dest) && return dest
bc = Broadcast.preprocess(dest, bc)
@kernel function broadcast_kernel(dest, bc)
I = @index(Global, Cartesian)
@inbounds dest[I] = bc[I]
end
broadcast_kernel(get_backend(dest))(dest, bc; ndrange=size(dest))
return dest
end Without
With
The overhead scales, e.g., using 4k x 4k inputs instead:
Generated code looks pretty bad, with both extra exceptions, branches, and argument mangling: define ptx_kernel void @old({ i64, i32 } %state, { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, { [1 x float], [2 x [1 x i64]] } %1) local_unnamed_addr {
conversion:
%.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 0
%.fca.3.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 3
%.fca.0.0.extract = extractvalue { [1 x float], [2 x [1 x i64]] } %1, 0, 0
%2 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%3 = zext i32 %2 to i64
%4 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
%5 = zext i32 %4 to i64
%6 = mul nuw nsw i64 %3, %5
%7 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%8 = add nuw nsw i32 %7, 1
%9 = zext i32 %8 to i64
%10 = add nuw nsw i64 %6, %9
%.not = icmp sgt i64 %10, %.fca.3.extract
br i1 %.not, label %L165, label %L30
L30: ; preds = %conversion
%.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 2, 0
%.not6 = icmp eq i64 %.fca.2.0.extract, 0
br i1 %.not6, label %fail, label %pass
L165: ; preds = %pass, %conversion
ret void
fail: ; preds = %L30
call fastcc void @gpu_report_exception({ i64, i32 } %state, i64 ptrtoint ([10 x i8]* @exception16 to i64))
call fastcc void @gpu_signal_exception({ i64, i32 } %state)
call void @llvm.trap()
call void @llvm.trap()
call void asm sideeffect "exit;", ""()
unreachable
pass: ; preds = %L30
%11 = add nsw i64 %10, -1
%12 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
%13 = getelementptr inbounds float, float addrspace(1)* %12, i64 %11
store float %.fca.0.0.extract, float addrspace(1)* %13, align 4
br label %L165
} define ptx_kernel void @new({ i64, i32 } %state, { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, { [1 x float], [2 x [1 x i64]] } %2) local_unnamed_addr {
conversion:
%.fca.0.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 0, 0, 0, 0
%.fca.0.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 0, 0, 1, 0
%.fca.1.0.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 0, 0, 0, 0
%.fca.1.1.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 1, 0, 0, 0
%.fca.1.1.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 1, 0, 1, 0
%.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 0
%.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 2, 0
%.fca.0.0.extract = extractvalue { [1 x float], [2 x [1 x i64]] } %2, 0, 0
%.not = icmp eq i64 %.fca.1.0.0.0.0.extract, 0
br i1 %.not, label %fail, label %pass
L527: ; preds = %pass6, %pass2
ret void
fail: ; preds = %conversion
call fastcc void @gpu_report_exception({ i64, i32 } %state, i64 ptrtoint ([10 x i8]* @exception19 to i64))
call fastcc void @gpu_signal_exception({ i64, i32 } %state)
call void @llvm.trap()
call void @llvm.trap()
call void asm sideeffect "exit;", ""()
unreachable
pass: ; preds = %conversion
%.not15 = icmp eq i64 %.fca.1.1.0.0.0.extract, 0
br i1 %.not15, label %fail1, label %pass2
fail1: ; preds = %pass
call fastcc void @gpu_report_exception({ i64, i32 } %state, i64 ptrtoint ([10 x i8]* @exception19 to i64))
call fastcc void @gpu_signal_exception({ i64, i32 } %state)
call void @llvm.trap()
call void @llvm.trap()
call void asm sideeffect "exit;", ""()
unreachable
pass2: ; preds = %pass
%3 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%4 = zext i32 %3 to i64
%5 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%6 = zext i32 %5 to i64
%7 = sdiv i64 %6, %.fca.1.0.0.0.0.extract
%.neg16 = mul i64 %7, %.fca.1.0.0.0.0.extract
%8 = sdiv i64 %4, %.fca.1.1.0.0.0.extract
%9 = add nsw i64 %8, 1
%reass.add17 = add i64 %.neg16, %8
%reass.add = sub i64 %6, %reass.add17
%reass.mul = mul i64 %reass.add, %.fca.1.1.0.0.0.extract
%10 = add nuw nsw i64 %4, 1
%11 = add i64 %10, %reass.mul
%12 = mul i64 %7, %.fca.1.1.0.1.0.extract
%13 = add i64 %9, %12
%14 = icmp sgt i64 %11, 0
%15 = icmp sle i64 %11, %.fca.0.0.0.0.extract
%16 = and i1 %14, %15
%17 = icmp sgt i64 %13, 0
%18 = icmp sle i64 %13, %.fca.0.0.1.0.extract
%19 = and i1 %17, %18
%20 = and i1 %19, %16
br i1 %20, label %pass6, label %L527
pass6: ; preds = %pass2
%21 = add i64 %12, %8
%22 = mul i64 %21, %.fca.2.0.extract
%23 = add i64 %22, %4
%24 = add i64 %23, %reass.mul
%25 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
%26 = getelementptr inbounds float, float addrspace(1)* %25, i64 %24
store float %.fca.0.0.extract, float addrspace(1)* %26, align 4
br label %L527
} This results in a much higher register usage, from 28 to 64 (at the PTX level). cc @vchuravy |
Looks like most of the added code comes from KA's nditeration handling: pass2: ; preds = %pass
; │└└└└└└└└└
; │┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:92 within `#threadIdx`
; ││┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:46 within `threadIdx_x`
; │││┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:7 within `_index`
; ││││┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:7 within `macro expansion` @ /home/tim/.julia/packages/LLVM/joxPv/src/interop/base.jl:39
%11 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
; │└└└└
; │┌ @ /home/tim/.julia/packages/KernelAbstractions/491pi/src/nditeration.jl:84 within `expand`
; ││┌ @ abstractarray.jl:1312 within `getindex`
; │││┌ @ abstractarray.jl:1353 within `_getindex`
; ││││┌ @ abstractarray.jl:1360 within `_to_subscript_indices`
; │││││┌ @ abstractarray.jl:1382 within `_unsafe_ind2sub`
; ││││││┌ @ abstractarray.jl:3053 within `_ind2sub` @ abstractarray.jl:3091
; │││││││┌ @ int.jl:86 within `-`
%12 = zext i32 %11 to i64
; │└└└└└└└
; │┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:78 within `#blockIdx`
; ││┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:56 within `blockIdx_x`
; │││┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:7 within `_index`
; ││││┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:7 within `macro expansion` @ /home/tim/.julia/packages/LLVM/joxPv/src/interop/base.jl:39
%13 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
; │└└└└
; │┌ @ /home/tim/.julia/packages/KernelAbstractions/491pi/src/nditeration.jl:84 within `expand`
; ││┌ @ abstractarray.jl:1312 within `getindex`
; │││┌ @ abstractarray.jl:1353 within `_getindex`
; ││││┌ @ abstractarray.jl:1360 within `_to_subscript_indices`
; │││││┌ @ abstractarray.jl:1382 within `_unsafe_ind2sub`
; ││││││┌ @ abstractarray.jl:3053 within `_ind2sub` @ abstractarray.jl:3091
; │││││││┌ @ int.jl:86 within `-`
%14 = zext i32 %13 to i64
; │││││││└
; │││││││┌ @ abstractarray.jl:3104 within `_ind2sub_recurse`
; ││││││││┌ @ abstractarray.jl:3111 within `_div`
; │││││││││┌ @ int.jl:295 within `div`
%15 = sdiv i64 %14, %.fca.1.0.0.0.0.extract
%.neg17 = mul i64 %15, %.fca.1.0.0.0.0.extract
%16 = sdiv i64 %12, %.fca.1.1.0.0.0.extract
; ││││││││└└
; ││││││││ @ abstractarray.jl:3105 within `_ind2sub_recurse` @ abstractarray.jl:3099
; ││││││││┌ @ abstractarray.jl:3109 within `_lookup`
; │││││││││┌ @ int.jl:87 within `+`
%17 = add nsw i64 %16, 1
%reass.add18 = add i64 %.neg17, %16
%reass.add = sub i64 %14, %reass.add18
%reass.mul = mul i64 %reass.add, %.fca.1.1.0.0.0.extract
; ││││││││└└
; ││││││││ @ abstractarray.jl:3105 within `_ind2sub_recurse`
; ││││││││┌ @ int.jl:87 within `+`
%18 = add nuw nsw i64 %12, 1
; ││└└└└└└└
; ││ @ /home/tim/.julia/packages/KernelAbstractions/491pi/src/nditeration.jl:84 within `expand` @ /home/tim/.julia/packages/KernelAbstractions/491pi/src/nditeration.jl:74
; ││┌ @ ntuple.jl:49 within `ntuple`
; │││┌ @ /home/tim/.julia/packages/KernelAbstractions/491pi/src/nditeration.jl:78 within `#1`
; ││││┌ @ int.jl:87 within `+`
%19 = add i64 %18, %reass.mul
; ││││└
; ││││┌ @ int.jl:88 within `*`
%20 = mul i64 %15, %.fca.1.1.0.1.0.extract
; ││││└
; ││││┌ @ int.jl:87 within `+`
%21 = add i64 %17, %20
; │└└└└
; │ @ /home/tim/Julia/pkg/CUDA/src/CUDAKernels.jl:168 within `#__validindex`
; │┌ @ multidimensional.jl:477 within `in`
; ││┌ @ tuple.jl:383 within `map`
; │││┌ @ range.jl:1426 within `in`
; ││││┌ @ int.jl:514 within `<=`
%22 = icmp sgt i64 %19, 0
%23 = icmp sle i64 %19, %.fca.0.0.0.0.extract
; ││││└
; ││││┌ @ bool.jl:38 within `&`
%24 = and i1 %22, %23
; ││││└
; ││││┌ @ int.jl:514 within `<=`
%25 = icmp sgt i64 %21, 0
%26 = icmp sle i64 %21, %.fca.0.0.1.0.extract
; ││││└
; ││││┌ @ bool.jl:38 within `&`
%27 = and i1 %25, %26
; ││└└└
; ││┌ @ tuple.jl:664 within `all`
; │││┌ @ bool.jl:38 within `&`
%28 = and i1 %27, %24
; └└└└
br i1 %28, label %L242, label %L538 A couple of other things that stand out:
|
Testing on JuliaGPU/KernelAbstractions.jl#518, a bit of performance is recovered, but it remains bad:
Updated MWE: using CUDA, KernelAbstractions, Chairmarks
using LLVM, LLVM.Interop
"""
    main()

Benchmark the two broadcast-copy implementations on a larger input.

Allocates an uninitialized 4000×4000 `Float32` `CuArray`, builds a broadcast of
`identity` over the scalar `0f0` that is instantiated with the destination's
axes, and then displays the `@b` timing of `copyto_old!` and `copyto_new!`,
each wrapped in `CUDA.@sync`.
"""
function main()
    dest = CuArray{Float32}(undef, 4000, 4000)

    # Build the lazy scalar broadcast first, then re-wrap it with the axes of
    # the destination so that it can be indexed like `dest`.
    scalar_bc = Broadcast.broadcasted(identity, 0f0)
    shaped_bc = Broadcast.instantiate(
        Broadcast.Broadcasted(scalar_bc.f, scalar_bc.args, axes(dest)),
    )

    print("Old: ")
    display(@b CUDA.@sync copyto_old!(dest, shaped_bc))

    print("New: ")
    display(@b CUDA.@sync copyto_new!(dest, shaped_bc))
end
"""
    copyto_old!(dest::AbstractArray, bc)

Materialize the broadcast `bc` into `dest` with a hand-written CUDA kernel and
return `dest`.

Throws via `Broadcast.throwdm` when the axes of `dest` and `bc` differ, and
returns immediately when `dest` is empty. Inside the kernel every dimension of
`dest` is passed through `assume` as being positive before the linear index is
converted to a Cartesian one.
"""
@inline function copyto_old!(dest::AbstractArray, bc)
    if !(axes(dest) == axes(bc))
        Broadcast.throwdm(axes(dest), axes(bc))
    end
    if isempty(dest)
        return dest
    end
    bc = Broadcast.preprocess(dest, bc)

    function broadcast_kernel(dest, bc)
        linear = threadIdx().x + (blockIdx().x - 1) * blockDim().x
        # Tell the compiler that every dimension is strictly positive.
        assume.(size(dest) .> 0)
        # Threads past the end of the array have nothing to do.
        linear > length(dest) && return
        @inbounds begin
            cart = CartesianIndices(dest)[linear]
            dest[cart] = bc[cart]
        end
        return
    end

    # Compile without launching so the launch configuration can be queried.
    kernel = @cuda launch=false broadcast_kernel(dest, bc)
    config = launch_configuration(kernel.fun)
    nelems = length(dest)
    nthreads = min(nelems, config.threads)
    nblocks = cld(nelems, nthreads)
    kernel(dest, bc; threads=nthreads, blocks=nblocks)
    return dest
end
@inline function copyto_new!(dest::AbstractArray, bc)
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
isempty(dest) && return dest
bc = Broadcast.preprocess(dest, bc)
@kernel function broadcast_kernel(dest, bc)
I = @index(Global, Cartesian)
@inbounds dest[I] = bc[I]
end
broadcast_kernel(get_backend(dest))(dest, bc; ndrange=size(dest))
return dest
end Old: define ptx_kernel void @_Z16broadcast_kernel13CuDeviceArrayI7Float32Li2ELi1EE11BroadcastedI17DefaultArrayStyleILi0EE5TupleI5OneToI5Int64ES8_E8identityS5_IS0_EE({ i64, i32 } %state, { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, { [1 x float], [2 x [1 x i64]] } %1) local_unnamed_addr {
conversion:
%.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 2, 0
%.fca.2.1.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 2, 1
%.fca.3.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 3
%2 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%3 = zext i32 %2 to i64
%4 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
%5 = zext i32 %4 to i64
%6 = mul nuw nsw i64 %3, %5
%7 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%8 = add nuw nsw i32 %7, 1
%9 = zext i32 %8 to i64
%10 = add nuw nsw i64 %6, %9
%11 = icmp sgt i64 %.fca.2.0.extract, 0
call void @llvm.assume(i1 %11)
%12 = icmp sgt i64 %.fca.2.1.extract, 0
call void @llvm.assume(i1 %12)
%.not = icmp sgt i64 %10, %.fca.3.extract
br i1 %.not, label %L176, label %pass
L176: ; preds = %pass, %conversion
ret void
pass: ; preds = %conversion
%.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 0
%13 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
%14 = add nsw i64 %10, -1
%15 = getelementptr inbounds float, float addrspace(1)* %13, i64 %14
%.fca.0.0.extract = extractvalue { [1 x float], [2 x [1 x i64]] } %1, 0, 0
store float %.fca.0.0.extract, float addrspace(1)* %15, align 4
br label %L176
} New:
The added arithmetic instructions are very clear now, and the It's curious that the old version doesn't have any integer divisions to convert the linear index into a global one. Presumably this is possible because the array style of the broadcast is 0D (we're broadcasting a simple scalar), and when using KA.jl that somehow gets lost. Indeed, when broadcasting an actual array (where the old code would have to function main()
a = CuArray{Float32}(undef, 4000, 4000)
b = CuArray{Float32}(undef, 4000, 4000)
bc = Broadcast.broadcasted(identity, b)
bc = Broadcast.instantiate(Broadcast.Broadcasted(bc.f, bc.args, axes(a)))
print("Old: ")
display(@b CUDA.@sync copyto_old!(a, bc))
print("New: ")
display(@b CUDA.@sync copyto_new!(a, bc))
end
Old: define ptx_kernel void @_Z16broadcast_kernel13CuDeviceArrayI7Float32Li2ELi1EE11BroadcastedI12CuArrayStyleILi2E12DeviceMemoryE5TupleI5OneToI5Int64ES9_E8identityS6_I8ExtrudedIS1_S6_I4BoolSD_ES6_IS8_S8_EEEE({ i64, i32 } %state, { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %1) local_unnamed_addr {
conversion:
%.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 2, 0
%.fca.2.1.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 2, 1
%.fca.3.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 3
%2 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%3 = zext i32 %2 to i64
%4 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
%5 = zext i32 %4 to i64
%6 = mul nuw nsw i64 %3, %5
%7 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%8 = add nuw nsw i32 %7, 1
%9 = zext i32 %8 to i64
%10 = add nuw nsw i64 %6, %9
%11 = icmp sgt i64 %.fca.2.0.extract, 0
call void @llvm.assume(i1 %11)
%12 = icmp sgt i64 %.fca.2.1.extract, 0
call void @llvm.assume(i1 %12)
%.not = icmp sgt i64 %10, %.fca.3.extract
br i1 %.not, label %L221, label %pass
L221: ; preds = %pass, %conversion
ret void
pass: ; preds = %conversion
%.fca.0.0.2.1.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %1, 0, 0, 2, 1
%.fca.0.0.2.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %1, 0, 0, 2, 0
%.fca.0.0.1.1.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %1, 0, 0, 1, 1
%.fca.0.0.1.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %1, 0, 0, 1, 0
%.fca.0.0.0.2.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %1, 0, 0, 0, 2, 0
%.fca.0.0.0.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %1, 0, 0, 0, 0
%.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 0
%13 = add nsw i64 %10, -1
%14 = udiv i64 %13, %.fca.2.0.extract
%15 = mul i64 %14, %.fca.2.0.extract
%16 = sub i64 %13, %15
%17 = add i64 %16, 1
%18 = add nuw nsw i64 %14, 1
%19 = and i8 %.fca.0.0.1.0.extract, 1
%.not7 = icmp eq i8 %19, 0
%20 = select i1 %.not7, i64 %.fca.0.0.2.0.extract, i64 %17
%21 = and i8 %.fca.0.0.1.1.extract, 1
%.not8 = icmp eq i8 %21, 0
%22 = select i1 %.not8, i64 %.fca.0.0.2.1.extract, i64 %18
%23 = add i64 %22, -1
%24 = mul i64 %23, %.fca.0.0.0.2.0.extract
%25 = add i64 %24, -1
%26 = add i64 %25, %20
%27 = bitcast i8 addrspace(1)* %.fca.0.0.0.0.extract to float addrspace(1)*
%28 = getelementptr inbounds float, float addrspace(1)* %27, i64 %26
%29 = load float, float addrspace(1)* %28, align 4
%30 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
%31 = getelementptr inbounds float, float addrspace(1)* %30, i64 %13
store float %29, float addrspace(1)* %31, align 4
br label %L221
} Notice the KA.jl: define ptx_kernel void @_Z20gpu_broadcast_kernel16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi2E5TupleI5OneToI5Int64ES6_EE7NDRangeILi2ES0_S0_S8_S8_EE13CuDeviceArrayI7Float32Li2ELi1EE11BroadcastedI12CuArrayStyleILi2E12DeviceMemoryES7_8identityS3_I8ExtrudedISE_S3_I4BoolSL_ES3_IS5_S5_EEEE({ i64, i32 } %state, { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %2) local_unnamed_addr {
conversion:
%.fca.0.0.0.0.extract6 = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 0, 0, 0, 0
%.fca.0.0.1.0.extract7 = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 0, 0, 1, 0
%.fca.1.0.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 0, 0, 0, 0
%.fca.1.0.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 0, 0, 1, 0
%.fca.1.1.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 1, 0, 0, 0
%.fca.1.1.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 1, 0, 1, 0
%3 = icmp sgt i64 %.fca.1.0.0.0.0.extract, 0
call void @llvm.assume(i1 %3)
%4 = icmp sgt i64 %.fca.1.0.0.1.0.extract, 0
call void @llvm.assume(i1 %4)
%5 = icmp sgt i64 %.fca.1.1.0.0.0.extract, 0
call void @llvm.assume(i1 %5)
%6 = icmp sgt i64 %.fca.1.1.0.1.0.extract, 0
call void @llvm.assume(i1 %6)
%7 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%8 = zext i32 %7 to i64
%9 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%10 = zext i32 %9 to i64
%11 = udiv i64 %10, %.fca.1.0.0.0.0.extract
%.neg28 = mul i64 %11, %.fca.1.0.0.0.0.extract
%12 = udiv i64 %8, %.fca.1.1.0.0.0.extract
%13 = add nuw nsw i64 %12, 1
%reass.add29 = add i64 %.neg28, %12
%reass.add = sub i64 %10, %reass.add29
%reass.mul = mul i64 %reass.add, %.fca.1.1.0.0.0.extract
%14 = add nuw nsw i64 %8, 1
%15 = add i64 %14, %reass.mul
%16 = mul i64 %11, %.fca.1.1.0.1.0.extract
%17 = add i64 %13, %16
%18 = icmp sgt i64 %15, 0
%19 = icmp sle i64 %15, %.fca.0.0.0.0.extract6
%20 = and i1 %18, %19
%21 = icmp sgt i64 %17, 0
%22 = icmp sle i64 %17, %.fca.0.0.1.0.extract7
%23 = and i1 %21, %22
%24 = and i1 %23, %20
br i1 %24, label %pass6, label %L630
L630: ; preds = %pass6, %conversion
ret void
pass6: ; preds = %conversion
%.fca.0.0.2.1.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %2, 0, 0, 2, 1
%.fca.0.0.2.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %2, 0, 0, 2, 0
%.fca.0.0.1.1.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %2, 0, 0, 1, 1
%.fca.0.0.1.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %2, 0, 0, 1, 0
%.fca.0.0.0.2.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %2, 0, 0, 0, 2, 0
%.fca.0.0.0.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %2, 0, 0, 0, 0
%.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 2, 0
%.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 0
%25 = sub i64 %10, %.neg28
%26 = mul i64 %12, %.fca.1.1.0.0.0.extract
%27 = sub i64 %8, %26
%28 = mul i64 %25, %.fca.1.1.0.0.0.extract
%29 = add i64 %28, %27
%30 = add i64 %29, 1
%31 = add i64 %16, %12
%32 = add i64 %31, 1
%33 = and i8 %.fca.0.0.1.0.extract, 1
%.not = icmp eq i8 %33, 0
%34 = select i1 %.not, i64 %.fca.0.0.2.0.extract, i64 %30
%35 = and i8 %.fca.0.0.1.1.extract, 1
%.not27 = icmp eq i8 %35, 0
%36 = select i1 %.not27, i64 %.fca.0.0.2.1.extract, i64 %32
%37 = add i64 %36, -1
%38 = mul i64 %37, %.fca.0.0.0.2.0.extract
%39 = add i64 %38, -1
%40 = add i64 %39, %34
%41 = bitcast i8 addrspace(1)* %.fca.0.0.0.0.extract to float addrspace(1)*
%42 = getelementptr inbounds float, float addrspace(1)* %41, i64 %40
%43 = load float, float addrspace(1)* %42, align 4
%44 = mul i64 %31, %.fca.2.0.extract
%45 = add i64 %29, %44
%46 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
%47 = getelementptr inbounds float, float addrspace(1)* %46, i64 %45
store float %43, float addrspace(1)* %47, align 4
br label %L630
} Still a lot more code, and an additional |
JuliaGPU/KernelAbstractions.jl#539 results in only a single |
With JuliaGPU/KernelAbstractions.jl#539, performance is interestingly worse. For the scalar broadcast:
(as opposed to 61us on master, and 56us on JuliaGPU/KernelAbstractions.jl#518) For a 2D broadcast:
... which is again a bit slower than before. Adding some
|
To summarize, there seem to be (at least) three areas to improve / optimize:
These are all visible in the MWE from #565 (comment) |
The switch to KA.jl significantly slowed down several operations.
CUDA.jl:
permutedims
, broadcast
, and many others: https://speed.juliagpu.org/changes/?tre=10&rev=6221589f5befec8f6f157a5a5271667dba09d0b6&exe=11&env=1
Metal.jl:
permutedims
The text was updated successfully, but these errors were encountered: