KA.jl-related slowdowns #565

Open
maleadt opened this issue Oct 18, 2024 · 6 comments

maleadt commented Oct 18, 2024

The switch to KA.jl significantly slowed down several operations.


CUDA.jl: permutedims, broadcast, and many others

https://speed.juliagpu.org/changes/?tre=10&rev=6221589f5befec8f6f157a5a5271667dba09d0b6&exe=11&env=1


Metal.jl: permutedims

Benchmark suite 	Current 	Previous 	Ratio
private array/permutedims/4d 	2911500 ns 	860084 ns 	3.39
private array/permutedims/2d 	1065021 ns 	862229.5 ns 	1.24
private array/permutedims/3d 	1629229 ns 	919520.5 ns 	1.77
shared array/permutedims/4d 	2933000 ns 	858875 ns 	3.41
shared array/permutedims/2d 	1054250 ns 	862292 ns 	1.22
shared array/permutedims/3d 	1625958 ns 	923916.5 ns 	1.76

maleadt commented Oct 18, 2024

Hmm, this is a little disconcerting: even a very simple Cartesian kernel exhibits a significant slowdown.
Reduced from our broadcast implementation:

using CUDA, KernelAbstractions, Chairmarks

function main()
    a = CuArray{Float32}(undef, 512, 1000)
    bc = Broadcast.broadcasted(identity, 0f0)
    bc = Broadcast.instantiate(Broadcast.Broadcasted(bc.f, bc.args, axes(a)))
    print("Old: ")
    display(@b CUDA.@sync copyto_old!(a, bc))
    print("New: ")
    display(@b CUDA.@sync copyto_new!(a, bc))
end

@inline function copyto_old!(dest::AbstractArray, bc)
    axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
    isempty(dest) && return dest
    bc = Broadcast.preprocess(dest, bc)

    function broadcast_kernel(dest, bc)
        i = (blockIdx().x-1) * blockDim().x + threadIdx().x
        @inbounds if i <= length(dest)
            I = CartesianIndices(dest)[i]
            dest[I] = bc[I]
        end
        return
    end

    kernel = @cuda launch=false broadcast_kernel(dest, bc)
    config = launch_configuration(kernel.fun)
    threads = min(length(dest), config.threads)
    blocks = cld(length(dest), threads)
    kernel(dest, bc; threads, blocks)

    return dest
end

@inline function copyto_new!(dest::AbstractArray, bc)
    axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
    isempty(dest) && return dest
    bc = Broadcast.preprocess(dest, bc)

    @kernel function broadcast_kernel(dest, bc)
        I = @index(Global, Cartesian)
        @inbounds dest[I] = bc[I]
    end

    broadcast_kernel(get_backend(dest))(dest, bc; ndrange=size(dest))

    return dest
end

Without CUDA.@sync, i.e., measuring launch overhead:

julia> main()
Old: 3.199 μs (12 allocs: 240 bytes)
New: 5.192 μs (58 allocs: 1.859 KiB)

With CUDA.@sync, i.e., measuring execution time:

julia> main()
Old: 7.746 μs (12 allocs: 240 bytes)
New: 10.940 μs (58 allocs: 1.859 KiB)

The overhead scales, e.g., using 4k x 4k inputs instead:

julia> main()
Old: 30.230 μs (12 allocs: 240 bytes)
New: 61.250 μs (58 allocs: 1.859 KiB)
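
To put numbers on that, here is a small sketch that computes the slowdown factor for each of the three pairs of timings above. The labels and variable names are only for illustration.

```julia
# (label, old μs, new μs), copied from the measurements above
timings = [
    ("launch only, 512x1000", 3.199, 5.192),
    ("synchronized, 512x1000", 7.746, 10.940),
    ("synchronized, 4k x 4k", 30.230, 61.250),
]

# Slowdown factor of the KA.jl kernel relative to the old one
slowdowns = [(label, round(new_us / old_us; digits = 2)) for (label, old_us, new_us) in timings]

for (label, factor) in slowdowns
    println(rpad(label, 26), factor, "x")
end
```

The factor grows from roughly 1.4x to roughly 2x on the larger input, which is consistent with the extra cost sitting in the kernel itself and not only in the launch path.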

The generated code looks pretty bad, with extra exceptions, branches, and argument mangling:

define ptx_kernel void @old({ i64, i32 } %state, { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, { [1 x float], [2 x [1 x i64]] } %1) local_unnamed_addr {
conversion:
  %.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 0
  %.fca.3.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 3
  %.fca.0.0.extract = extractvalue { [1 x float], [2 x [1 x i64]] } %1, 0, 0
  %2 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
  %3 = zext i32 %2 to i64
  %4 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
  %5 = zext i32 %4 to i64
  %6 = mul nuw nsw i64 %3, %5
  %7 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  %8 = add nuw nsw i32 %7, 1
  %9 = zext i32 %8 to i64
  %10 = add nuw nsw i64 %6, %9
  %.not = icmp sgt i64 %10, %.fca.3.extract
  br i1 %.not, label %L165, label %L30

L30:                                              ; preds = %conversion
  %.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 2, 0
  %.not6 = icmp eq i64 %.fca.2.0.extract, 0
  br i1 %.not6, label %fail, label %pass

L165:                                             ; preds = %pass, %conversion
  ret void

fail:                                             ; preds = %L30
  call fastcc void @gpu_report_exception({ i64, i32 } %state, i64 ptrtoint ([10 x i8]* @exception16 to i64))
  call fastcc void @gpu_signal_exception({ i64, i32 } %state)
  call void @llvm.trap()
  call void @llvm.trap()
  call void asm sideeffect "exit;", ""()
  unreachable

pass:                                             ; preds = %L30
  %11 = add nsw i64 %10, -1
  %12 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
  %13 = getelementptr inbounds float, float addrspace(1)* %12, i64 %11
  store float %.fca.0.0.extract, float addrspace(1)* %13, align 4
  br label %L165
}
define ptx_kernel void @new({ i64, i32 } %state, { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, { [1 x float], [2 x [1 x i64]] } %2) local_unnamed_addr {
conversion:
  %.fca.0.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 0, 0, 0, 0
  %.fca.0.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 0, 0, 1, 0
  %.fca.1.0.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 0, 0, 0, 0
  %.fca.1.1.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 1, 0, 0, 0
  %.fca.1.1.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 1, 0, 1, 0
  %.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 0
  %.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 2, 0
  %.fca.0.0.extract = extractvalue { [1 x float], [2 x [1 x i64]] } %2, 0, 0
  %.not = icmp eq i64 %.fca.1.0.0.0.0.extract, 0
  br i1 %.not, label %fail, label %pass

L527:                                             ; preds = %pass6, %pass2
  ret void

fail:                                             ; preds = %conversion
  call fastcc void @gpu_report_exception({ i64, i32 } %state, i64 ptrtoint ([10 x i8]* @exception19 to i64))
  call fastcc void @gpu_signal_exception({ i64, i32 } %state)
  call void @llvm.trap()
  call void @llvm.trap()
  call void asm sideeffect "exit;", ""()
  unreachable

pass:                                             ; preds = %conversion
  %.not15 = icmp eq i64 %.fca.1.1.0.0.0.extract, 0
  br i1 %.not15, label %fail1, label %pass2

fail1:                                            ; preds = %pass
  call fastcc void @gpu_report_exception({ i64, i32 } %state, i64 ptrtoint ([10 x i8]* @exception19 to i64))
  call fastcc void @gpu_signal_exception({ i64, i32 } %state)
  call void @llvm.trap()
  call void @llvm.trap()
  call void asm sideeffect "exit;", ""()
  unreachable

pass2:                                            ; preds = %pass
  %3 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  %4 = zext i32 %3 to i64
  %5 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
  %6 = zext i32 %5 to i64
  %7 = sdiv i64 %6, %.fca.1.0.0.0.0.extract
  %.neg16 = mul i64 %7, %.fca.1.0.0.0.0.extract
  %8 = sdiv i64 %4, %.fca.1.1.0.0.0.extract
  %9 = add nsw i64 %8, 1
  %reass.add17 = add i64 %.neg16, %8
  %reass.add = sub i64 %6, %reass.add17
  %reass.mul = mul i64 %reass.add, %.fca.1.1.0.0.0.extract
  %10 = add nuw nsw i64 %4, 1
  %11 = add i64 %10, %reass.mul
  %12 = mul i64 %7, %.fca.1.1.0.1.0.extract
  %13 = add i64 %9, %12
  %14 = icmp sgt i64 %11, 0
  %15 = icmp sle i64 %11, %.fca.0.0.0.0.extract
  %16 = and i1 %14, %15
  %17 = icmp sgt i64 %13, 0
  %18 = icmp sle i64 %13, %.fca.0.0.1.0.extract
  %19 = and i1 %17, %18
  %20 = and i1 %19, %16
  br i1 %20, label %pass6, label %L527

pass6:                                            ; preds = %pass2
  %21 = add i64 %12, %8
  %22 = mul i64 %21, %.fca.2.0.extract
  %23 = add i64 %22, %4
  %24 = add i64 %23, %reass.mul
  %25 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
  %26 = getelementptr inbounds float, float addrspace(1)* %25, i64 %24
  store float %.fca.0.0.extract, float addrspace(1)* %26, align 4
  br label %L527
}

This results in much higher register usage: from 28 to 64 (at the PTX level).

cc @vchuravy


maleadt commented Oct 18, 2024

Looks like most of the added code comes from KA's nditeration handling:

pass2:                                            ; preds = %pass
; │└└└└└└└└└
; │┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:92 within `#threadIdx`
; ││┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:46 within `threadIdx_x`
; │││┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:7 within `_index`
; ││││┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:7 within `macro expansion` @ /home/tim/.julia/packages/LLVM/joxPv/src/interop/base.jl:39
       %11 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
; │└└└└
; │┌ @ /home/tim/.julia/packages/KernelAbstractions/491pi/src/nditeration.jl:84 within `expand`
; ││┌ @ abstractarray.jl:1312 within `getindex`
; │││┌ @ abstractarray.jl:1353 within `_getindex`
; ││││┌ @ abstractarray.jl:1360 within `_to_subscript_indices`
; │││││┌ @ abstractarray.jl:1382 within `_unsafe_ind2sub`
; ││││││┌ @ abstractarray.jl:3053 within `_ind2sub` @ abstractarray.jl:3091
; │││││││┌ @ int.jl:86 within `-`
          %12 = zext i32 %11 to i64
; │└└└└└└└
; │┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:78 within `#blockIdx`
; ││┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:56 within `blockIdx_x`
; │││┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:7 within `_index`
; ││││┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:7 within `macro expansion` @ /home/tim/.julia/packages/LLVM/joxPv/src/interop/base.jl:39
       %13 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
; │└└└└
; │┌ @ /home/tim/.julia/packages/KernelAbstractions/491pi/src/nditeration.jl:84 within `expand`
; ││┌ @ abstractarray.jl:1312 within `getindex`
; │││┌ @ abstractarray.jl:1353 within `_getindex`
; ││││┌ @ abstractarray.jl:1360 within `_to_subscript_indices`
; │││││┌ @ abstractarray.jl:1382 within `_unsafe_ind2sub`
; ││││││┌ @ abstractarray.jl:3053 within `_ind2sub` @ abstractarray.jl:3091
; │││││││┌ @ int.jl:86 within `-`
          %14 = zext i32 %13 to i64
; │││││││└
; │││││││┌ @ abstractarray.jl:3104 within `_ind2sub_recurse`
; ││││││││┌ @ abstractarray.jl:3111 within `_div`
; │││││││││┌ @ int.jl:295 within `div`
            %15 = sdiv i64 %14, %.fca.1.0.0.0.0.extract
            %.neg17 = mul i64 %15, %.fca.1.0.0.0.0.extract
            %16 = sdiv i64 %12, %.fca.1.1.0.0.0.extract
; ││││││││└└
; ││││││││ @ abstractarray.jl:3105 within `_ind2sub_recurse` @ abstractarray.jl:3099
; ││││││││┌ @ abstractarray.jl:3109 within `_lookup`
; │││││││││┌ @ int.jl:87 within `+`
            %17 = add nsw i64 %16, 1
            %reass.add18 = add i64 %.neg17, %16
            %reass.add = sub i64 %14, %reass.add18
            %reass.mul = mul i64 %reass.add, %.fca.1.1.0.0.0.extract
; ││││││││└└
; ││││││││ @ abstractarray.jl:3105 within `_ind2sub_recurse`
; ││││││││┌ @ int.jl:87 within `+`
           %18 = add nuw nsw i64 %12, 1
; ││└└└└└└└
; ││ @ /home/tim/.julia/packages/KernelAbstractions/491pi/src/nditeration.jl:84 within `expand` @ /home/tim/.julia/packages/KernelAbstractions/491pi/src/nditeration.jl:74
; ││┌ @ ntuple.jl:49 within `ntuple`
; │││┌ @ /home/tim/.julia/packages/KernelAbstractions/491pi/src/nditeration.jl:78 within `#1`
; ││││┌ @ int.jl:87 within `+`
       %19 = add i64 %18, %reass.mul
; ││││└
; ││││┌ @ int.jl:88 within `*`
       %20 = mul i64 %15, %.fca.1.1.0.1.0.extract
; ││││└
; ││││┌ @ int.jl:87 within `+`
       %21 = add i64 %17, %20
; │└└└└
; │ @ /home/tim/Julia/pkg/CUDA/src/CUDAKernels.jl:168 within `#__validindex`
; │┌ @ multidimensional.jl:477 within `in`
; ││┌ @ tuple.jl:383 within `map`
; │││┌ @ range.jl:1426 within `in`
; ││││┌ @ int.jl:514 within `<=`
       %22 = icmp sgt i64 %19, 0
       %23 = icmp sle i64 %19, %.fca.0.0.0.0.extract
; ││││└
; ││││┌ @ bool.jl:38 within `&`
       %24 = and i1 %22, %23
; ││││└
; ││││┌ @ int.jl:514 within `<=`
       %25 = icmp sgt i64 %21, 0
       %26 = icmp sle i64 %21, %.fca.0.0.1.0.extract
; ││││└
; ││││┌ @ bool.jl:38 within `&`
       %27 = and i1 %25, %26
; ││└└└
; ││┌ @ tuple.jl:664 within `all`
; │││┌ @ bool.jl:38 within `&`
      %28 = and i1 %27, %24
; └└└└
  br i1 %28, label %L242, label %L538

A couple of other things that stand out:

  • the KA.jl version has two sdivs
  • the CUDA.jl version computes a global ID using blockDim, while KA.jl somehow doesn't (I guess it never computes a global linear index?)
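
A CPU-only sketch of where the two divisions likely come from, assuming a two-level scheme in which the linear block index and the linear thread index are each converted to a Cartesian index before being combined. `expand_sketch` is a hypothetical helper written for illustration; it is a simplified model, not KA.jl's actual code, and the launch shape is an assumption.

```julia
# Hypothetical, simplified model of a two-level index expansion (not KA.jl's code).
function expand_sketch(blocks::CartesianIndices{N}, workitems::CartesianIndices{N},
                       groupidx::Integer, idx::Integer) where {N}
    g = blocks[groupidx]   # linear -> Cartesian: one integer division in 2D
    l = workitems[idx]     # linear -> Cartesian: another integer division in 2D
    # Combine per dimension: (group - 1) * workgroup extent + local index
    return CartesianIndex(ntuple(d -> (g[d] - 1) * size(workitems, d) + l[d], Val(N)))
end

# Assumed launch shape: a 512x1000 ndrange split into 256x1 workgroups.
workitems = CartesianIndices((256, 1))
blocks = CartesianIndices((2, 1000))

# Thread 5 of block 3, expanded through both levels
I_two_level = expand_sketch(blocks, workitems, 3, 5)

# Single-level alternative: one global linear index, one conversion.
i = (3 - 1) * 256 + 5
I_linear = CartesianIndices((512, 1000))[i]
```

Both paths land on the same element here, but the two-level version performs two linear-to-Cartesian conversions where the single-level one performs one.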


maleadt commented Oct 18, 2024

Testing on JuliaGPU/KernelAbstractions.jl#518, a bit of performance is recovered, but it remains bad:

Old: 29.800 μs (12 allocs: 240 bytes)
New: 56.749 μs (58 allocs: 1.859 KiB)

Updated MWE:

using CUDA, KernelAbstractions, Chairmarks
using LLVM, LLVM.Interop

function main()
    a = CuArray{Float32}(undef, 4000, 4000)
    bc = Broadcast.broadcasted(identity, 0f0)
    bc = Broadcast.instantiate(Broadcast.Broadcasted(bc.f, bc.args, axes(a)))
    print("Old: ")
    display(@b CUDA.@sync copyto_old!(a, bc))
    print("New: ")
    display(@b CUDA.@sync copyto_new!(a, bc))
end

@inline function copyto_old!(dest::AbstractArray, bc)
    axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
    isempty(dest) && return dest
    bc = Broadcast.preprocess(dest, bc)

    function broadcast_kernel(dest, bc)
        i = (blockIdx().x-1) * blockDim().x + threadIdx().x
        assume.(size(dest) .> 0)
        @inbounds if i <= length(dest)
            I = CartesianIndices(dest)[i]
            dest[I] = bc[I]
        end
        return
    end

    kernel = @cuda launch=false broadcast_kernel(dest, bc)
    config = launch_configuration(kernel.fun)
    threads = min(length(dest), config.threads)
    blocks = cld(length(dest), threads)
    kernel(dest, bc; threads, blocks)

    return dest
end

@inline function copyto_new!(dest::AbstractArray, bc)
    axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
    isempty(dest) && return dest
    bc = Broadcast.preprocess(dest, bc)

    @kernel function broadcast_kernel(dest, bc)
        I = @index(Global, Cartesian)
        @inbounds dest[I] = bc[I]
    end

    broadcast_kernel(get_backend(dest))(dest, bc; ndrange=size(dest))

    return dest
end

Old:

define ptx_kernel void @_Z16broadcast_kernel13CuDeviceArrayI7Float32Li2ELi1EE11BroadcastedI17DefaultArrayStyleILi0EE5TupleI5OneToI5Int64ES8_E8identityS5_IS0_EE({ i64, i32 } %state, { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, { [1 x float], [2 x [1 x i64]] } %1) local_unnamed_addr {
conversion:
  %.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 2, 0
  %.fca.2.1.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 2, 1
  %.fca.3.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 3
  %2 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
  %3 = zext i32 %2 to i64
  %4 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
  %5 = zext i32 %4 to i64
  %6 = mul nuw nsw i64 %3, %5
  %7 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  %8 = add nuw nsw i32 %7, 1
  %9 = zext i32 %8 to i64
  %10 = add nuw nsw i64 %6, %9
  %11 = icmp sgt i64 %.fca.2.0.extract, 0
  call void @llvm.assume(i1 %11)
  %12 = icmp sgt i64 %.fca.2.1.extract, 0
  call void @llvm.assume(i1 %12)
  %.not = icmp sgt i64 %10, %.fca.3.extract
  br i1 %.not, label %L176, label %pass

L176:                                             ; preds = %pass, %conversion
  ret void

pass:                                             ; preds = %conversion
  %.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 0
  %13 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
  %14 = add nsw i64 %10, -1
  %15 = getelementptr inbounds float, float addrspace(1)* %13, i64 %14
  %.fca.0.0.extract = extractvalue { [1 x float], [2 x [1 x i64]] } %1, 0, 0
  store float %.fca.0.0.extract, float addrspace(1)* %15, align 4
  br label %L176
}

New:

define ptx_kernel void @_Z20gpu_broadcast_kernel16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi2E5TupleI5OneToI5Int64ES6_EE7NDRangeILi2ES0_S0_S8_S8_EE13CuDeviceArrayI7Float32Li2ELi1EE11BroadcastedI17DefaultArrayStyleILi0EES7_8identityS3_ISD_EE({ i64, i32 } %state, { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, { [1 x float], [2 x [1 x i64]] } %2) local_unnamed_addr {
conversion:
  %.fca.0.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 0, 0, 0, 0
  %.fca.0.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 0, 0, 1, 0
  %.fca.1.0.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 0, 0, 0, 0
  %.fca.1.0.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 0, 0, 1, 0
  %.fca.1.1.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 1, 0, 0, 0
  %.fca.1.1.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 1, 0, 1, 0
  %3 = icmp sgt i64 %.fca.1.0.0.0.0.extract, 0
  call void @llvm.assume(i1 %3)
  %4 = icmp sgt i64 %.fca.1.0.0.1.0.extract, 0
  call void @llvm.assume(i1 %4)
  %5 = icmp sgt i64 %.fca.1.1.0.0.0.extract, 0
  call void @llvm.assume(i1 %5)
  %6 = icmp sgt i64 %.fca.1.1.0.1.0.extract, 0
  call void @llvm.assume(i1 %6)
  %7 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  %8 = zext i32 %7 to i64
  %9 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
  %10 = zext i32 %9 to i64
  %11 = udiv i64 %10, %.fca.1.0.0.0.0.extract
  %.neg22 = mul i64 %11, %.fca.1.0.0.0.0.extract
  %12 = udiv i64 %8, %.fca.1.1.0.0.0.extract
  %13 = add nuw nsw i64 %12, 1
  %reass.add23 = add i64 %.neg22, %12
  %reass.add = sub i64 %10, %reass.add23
  %reass.mul = mul i64 %reass.add, %.fca.1.1.0.0.0.extract
  %14 = add nuw nsw i64 %8, 1
  %15 = add i64 %14, %reass.mul
  %16 = mul i64 %11, %.fca.1.1.0.1.0.extract
  %17 = add i64 %13, %16
  %18 = icmp sgt i64 %15, 0
  %19 = icmp sle i64 %15, %.fca.0.0.0.0.extract
  %20 = and i1 %18, %19
  %21 = icmp sgt i64 %17, 0
  %22 = icmp sle i64 %17, %.fca.0.0.1.0.extract
  %23 = and i1 %21, %22
  %24 = and i1 %23, %20
  br i1 %24, label %pass6, label %L585

L585:                                             ; preds = %pass6, %conversion
  ret void

pass6:                                            ; preds = %conversion
  %.fca.0.0.extract = extractvalue { [1 x float], [2 x [1 x i64]] } %2, 0, 0
  %.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 2, 0
  %.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 0
  %25 = add i64 %16, %12
  %26 = mul i64 %25, %.fca.2.0.extract
  %27 = add i64 %26, %8
  %28 = add i64 %27, %reass.mul
  %29 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
  %30 = getelementptr inbounds float, float addrspace(1)* %29, i64 %28
  store float %.fca.0.0.extract, float addrspace(1)* %30, align 4
  br label %L585
}

The added arithmetic instructions are very clear now, and the udivs are probably still the culprit.


It's curious that the old version doesn't have any integer divisions to convert the linear index into a Cartesian one. Presumably this is possible because the array style of the broadcast is 0D (we're broadcasting a simple scalar), and when using KA.jl that somehow gets lost.
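
That presumption can be checked on the CPU with plain Base. The sketch below builds the same kind of scalar `Broadcasted` object as the MWE and shows that indexing it returns the scalar regardless of the index. The array size and the value `1f0` are arbitrary choices for illustration.

```julia
a = zeros(Float32, 4, 3)

# Same construction as in the MWE, with a nonzero scalar so the effect is visible.
bc = Broadcast.broadcasted(identity, 1f0)
bc = Broadcast.instantiate(Broadcast.Broadcasted(bc.f, bc.args, axes(a)))
bc = Broadcast.preprocess(a, bc)

# Every Cartesian index yields the same value: the index is never actually used.
values_seen = unique(bc[I] for I in CartesianIndices(a))

# Serial mirror of the old kernel body: linear loop, Cartesian lookup and store.
for i in 1:length(a)
    I = CartesianIndices(a)[i]
    a[I] = bc[I]
end
```

With a scalar source, the only remaining use of `I` is the store into `dest`, which is the part the old kernel's IR shows as a plain linear offset.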

Indeed, when broadcasting an actual array (where the old code would have to udiv too), the performance is much closer:

function main()
    a = CuArray{Float32}(undef, 4000, 4000)
    b = CuArray{Float32}(undef, 4000, 4000)
    bc = Broadcast.broadcasted(identity, b)
    bc = Broadcast.instantiate(Broadcast.Broadcasted(bc.f, bc.args, axes(a)))
    print("Old: ")
    display(@b CUDA.@sync copyto_old!(a, bc))
    print("New: ")
    display(@b CUDA.@sync copyto_new!(a, bc))
end
Old: 153.629 μs (42 allocs: 720 bytes)
New: 161.238 μs (88 allocs: 2.547 KiB)

Old:

define ptx_kernel void @_Z16broadcast_kernel13CuDeviceArrayI7Float32Li2ELi1EE11BroadcastedI12CuArrayStyleILi2E12DeviceMemoryE5TupleI5OneToI5Int64ES9_E8identityS6_I8ExtrudedIS1_S6_I4BoolSD_ES6_IS8_S8_EEEE({ i64, i32 } %state, { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %1) local_unnamed_addr {
conversion:
  %.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 2, 0
  %.fca.2.1.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 2, 1
  %.fca.3.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 3
  %2 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
  %3 = zext i32 %2 to i64
  %4 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
  %5 = zext i32 %4 to i64
  %6 = mul nuw nsw i64 %3, %5
  %7 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  %8 = add nuw nsw i32 %7, 1
  %9 = zext i32 %8 to i64
  %10 = add nuw nsw i64 %6, %9
  %11 = icmp sgt i64 %.fca.2.0.extract, 0
  call void @llvm.assume(i1 %11)
  %12 = icmp sgt i64 %.fca.2.1.extract, 0
  call void @llvm.assume(i1 %12)
  %.not = icmp sgt i64 %10, %.fca.3.extract
  br i1 %.not, label %L221, label %pass

L221:                                             ; preds = %pass, %conversion
  ret void

pass:                                             ; preds = %conversion
  %.fca.0.0.2.1.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %1, 0, 0, 2, 1
  %.fca.0.0.2.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %1, 0, 0, 2, 0
  %.fca.0.0.1.1.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %1, 0, 0, 1, 1
  %.fca.0.0.1.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %1, 0, 0, 1, 0
  %.fca.0.0.0.2.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %1, 0, 0, 0, 2, 0
  %.fca.0.0.0.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %1, 0, 0, 0, 0
  %.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 0
  %13 = add nsw i64 %10, -1
  %14 = udiv i64 %13, %.fca.2.0.extract
  %15 = mul i64 %14, %.fca.2.0.extract
  %16 = sub i64 %13, %15
  %17 = add i64 %16, 1
  %18 = add nuw nsw i64 %14, 1
  %19 = and i8 %.fca.0.0.1.0.extract, 1
  %.not7 = icmp eq i8 %19, 0
  %20 = select i1 %.not7, i64 %.fca.0.0.2.0.extract, i64 %17
  %21 = and i8 %.fca.0.0.1.1.extract, 1
  %.not8 = icmp eq i8 %21, 0
  %22 = select i1 %.not8, i64 %.fca.0.0.2.1.extract, i64 %18
  %23 = add i64 %22, -1
  %24 = mul i64 %23, %.fca.0.0.0.2.0.extract
  %25 = add i64 %24, -1
  %26 = add i64 %25, %20
  %27 = bitcast i8 addrspace(1)* %.fca.0.0.0.0.extract to float addrspace(1)*
  %28 = getelementptr inbounds float, float addrspace(1)* %27, i64 %26
  %29 = load float, float addrspace(1)* %28, align 4
  %30 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
  %31 = getelementptr inbounds float, float addrspace(1)* %30, i64 %13
  store float %29, float addrspace(1)* %31, align 4
  br label %L221
}

Notice the udiv.

KA.jl:

define ptx_kernel void @_Z20gpu_broadcast_kernel16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi2E5TupleI5OneToI5Int64ES6_EE7NDRangeILi2ES0_S0_S8_S8_EE13CuDeviceArrayI7Float32Li2ELi1EE11BroadcastedI12CuArrayStyleILi2E12DeviceMemoryES7_8identityS3_I8ExtrudedISE_S3_I4BoolSL_ES3_IS5_S5_EEEE({ i64, i32 } %state, { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %2) local_unnamed_addr {
conversion:
  %.fca.0.0.0.0.extract6 = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 0, 0, 0, 0
  %.fca.0.0.1.0.extract7 = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 0, 0, 1, 0
  %.fca.1.0.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 0, 0, 0, 0
  %.fca.1.0.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 0, 0, 1, 0
  %.fca.1.1.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 1, 0, 0, 0
  %.fca.1.1.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 1, 0, 1, 0
  %3 = icmp sgt i64 %.fca.1.0.0.0.0.extract, 0
  call void @llvm.assume(i1 %3)
  %4 = icmp sgt i64 %.fca.1.0.0.1.0.extract, 0
  call void @llvm.assume(i1 %4)
  %5 = icmp sgt i64 %.fca.1.1.0.0.0.extract, 0
  call void @llvm.assume(i1 %5)
  %6 = icmp sgt i64 %.fca.1.1.0.1.0.extract, 0
  call void @llvm.assume(i1 %6)
  %7 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  %8 = zext i32 %7 to i64
  %9 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
  %10 = zext i32 %9 to i64
  %11 = udiv i64 %10, %.fca.1.0.0.0.0.extract
  %.neg28 = mul i64 %11, %.fca.1.0.0.0.0.extract
  %12 = udiv i64 %8, %.fca.1.1.0.0.0.extract
  %13 = add nuw nsw i64 %12, 1
  %reass.add29 = add i64 %.neg28, %12
  %reass.add = sub i64 %10, %reass.add29
  %reass.mul = mul i64 %reass.add, %.fca.1.1.0.0.0.extract
  %14 = add nuw nsw i64 %8, 1
  %15 = add i64 %14, %reass.mul
  %16 = mul i64 %11, %.fca.1.1.0.1.0.extract
  %17 = add i64 %13, %16
  %18 = icmp sgt i64 %15, 0
  %19 = icmp sle i64 %15, %.fca.0.0.0.0.extract6
  %20 = and i1 %18, %19
  %21 = icmp sgt i64 %17, 0
  %22 = icmp sle i64 %17, %.fca.0.0.1.0.extract7
  %23 = and i1 %21, %22
  %24 = and i1 %23, %20
  br i1 %24, label %pass6, label %L630

L630:                                             ; preds = %pass6, %conversion
  ret void

pass6:                                            ; preds = %conversion
  %.fca.0.0.2.1.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %2, 0, 0, 2, 1
  %.fca.0.0.2.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %2, 0, 0, 2, 0
  %.fca.0.0.1.1.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %2, 0, 0, 1, 1
  %.fca.0.0.1.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %2, 0, 0, 1, 0
  %.fca.0.0.0.2.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %2, 0, 0, 0, 2, 0
  %.fca.0.0.0.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %2, 0, 0, 0, 0
  %.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 2, 0
  %.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 0
  %25 = sub i64 %10, %.neg28
  %26 = mul i64 %12, %.fca.1.1.0.0.0.extract
  %27 = sub i64 %8, %26
  %28 = mul i64 %25, %.fca.1.1.0.0.0.extract
  %29 = add i64 %28, %27
  %30 = add i64 %29, 1
  %31 = add i64 %16, %12
  %32 = add i64 %31, 1
  %33 = and i8 %.fca.0.0.1.0.extract, 1
  %.not = icmp eq i8 %33, 0
  %34 = select i1 %.not, i64 %.fca.0.0.2.0.extract, i64 %30
  %35 = and i8 %.fca.0.0.1.1.extract, 1
  %.not27 = icmp eq i8 %35, 0
  %36 = select i1 %.not27, i64 %.fca.0.0.2.1.extract, i64 %32
  %37 = add i64 %36, -1
  %38 = mul i64 %37, %.fca.0.0.0.2.0.extract
  %39 = add i64 %38, -1
  %40 = add i64 %39, %34
  %41 = bitcast i8 addrspace(1)* %.fca.0.0.0.0.extract to float addrspace(1)*
  %42 = getelementptr inbounds float, float addrspace(1)* %41, i64 %40
  %43 = load float, float addrspace(1)* %42, align 4
  %44 = mul i64 %31, %.fca.2.0.extract
  %45 = add i64 %29, %44
  %46 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
  %47 = getelementptr inbounds float, float addrspace(1)* %46, i64 %45
  store float %43, float addrspace(1)* %47, align 4
  br label %L630
}

Still a lot more code, and an additional udiv, but at least it makes the hypothesis more likely.


maleadt commented Oct 18, 2024

JuliaGPU/KernelAbstractions.jl#539 results in only a single sdiv, which is an improvement, but still not the division-free case reported here (which comes from broadcasting a scalar).


maleadt commented Oct 21, 2024

With JuliaGPU/KernelAbstractions.jl#539, performance is interestingly worse. For the scalar broadcast:

julia> main()
Old: 30.260 μs (12 allocs: 240 bytes)
New: 72.179 μs (53 allocs: 1.891 KiB)

(as opposed to 61 μs on master, and 56 μs on JuliaGPU/KernelAbstractions.jl#518)

For a 2D broadcast:

julia> main()
Old: 154.619 μs (42 allocs: 720 bytes)
New: 165.908 μs (84 allocs: 2.594 KiB)

... which is again a bit slower than before.


Adding some assume calls to get rid of all exceptions (the div-related one, the newly added DivError, and a remaining InexactError), I get the following for the scalar and 2D broadcasts, respectively:

julia> main()
Old: 28.200 μs (12 allocs: 240 bytes)
New: 50.119 μs (53 allocs: 1.891 KiB)
julia> main()
Old: 155.308 μs (42 allocs: 720 bytes)
New: 160.239 μs (83 allocs: 2.578 KiB)


maleadt commented Jan 13, 2025

To summarize, there seem to be (at least) three areas to improve / optimize:

  1. scalar broadcast now having udiv: somehow the scalar nature of the broadcast is getting lost

  2. more complex linear index calculations, relying on block size from arguments instead of using hardware indices: maybe this is solved by "Try fast linear indexes for KA" (CUDA.jl#2612)

  3. moving towards using ND hardware indices to avoid the linear->cartesian computation where possible

These are all visible in the MWE from #565 (comment)
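
For the third point, a CPU-only sketch of what ND hardware indices would buy: when the launch grid has the same dimensionality as the array, each thread can form its Cartesian index with one multiply-add per dimension and no integer division. `nd_global_index` and `simulate_launch` are hypothetical names, and the loops merely stand in for the hardware's block and thread indices.

```julia
# Hypothetical helper: per-dimension global index from ND block and thread indices.
nd_global_index(block::NTuple{N,Int}, blockdim::NTuple{N,Int}, thread::NTuple{N,Int}) where {N} =
    CartesianIndex(ntuple(d -> (block[d] - 1) * blockdim[d] + thread[d], Val(N)))

# Serial stand-in for an ND launch: count how often each element is visited.
function simulate_launch(dims::NTuple{N,Int}, blockdim::NTuple{N,Int}) where {N}
    visits = zeros(Int, dims)
    griddim = cld.(dims, blockdim)   # enough blocks to cover every dimension
    for block in CartesianIndices(griddim), thread in CartesianIndices(blockdim)
        I = nd_global_index(Tuple(block), blockdim, Tuple(thread))
        # The grid may overshoot the array, so a bounds check is still required.
        if checkbounds(Bool, visits, I)
            visits[I] += 1
        end
    end
    return visits
end

visits = simulate_launch((5, 7), (2, 4))
```

Every element of the 5x7 array is visited exactly once, and the only per-thread work besides the multiply-adds is the bounds check.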
