Skip to content

Commit

Permalink
Adapt to GPUArrays.jl transition to KernelAbstractions.jl. (#461)
Browse files Browse the repository at this point in the history
Co-authored-by: James Schloss <[email protected]>
  • Loading branch information
maleadt and leios authored Oct 18, 2024
1 parent 100f831 commit 711758d
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 57 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ BFloat16s = "0.5"
CEnum = "0.4, 0.5"
CodecBzip2 = "0.8"
ExprTools = "0.1"
GPUArrays = "10.1"
GPUArrays = "11"
GPUCompiler = "0.26, 0.27, 1"
KernelAbstractions = "0.9.1"
LLVM = "7.2, 8, 9"
Expand Down
54 changes: 0 additions & 54 deletions src/gpuarrays.jl
Original file line number Diff line number Diff line change
@@ -1,59 +1,5 @@
## GPUArrays interfaces

## execution

# Field-less marker type for the Metal array backend. It subtypes
# `AbstractGPUBackend` and is used purely for dispatch: the `launch_heuristic`
# and `gpu_call` methods in this file are selected by it, and
# `GPUArrays.backend(::Type{<:MtlArray})` returns an instance of it.
struct mtlArrayBackend <: AbstractGPUBackend
end

# Field-less marker type passed as the first argument to every kernel launched
# through this file (`f(mtlKernelContext(), args...)`). The on-device methods
# below (indexing, local memory, synchronization) dispatch on it.
struct mtlKernelContext <: AbstractKernelContext
end

# Choose a one-dimensional launch configuration for running `f` over `elements`
# work items on the Metal backend.
#
# Returns a named tuple `(; threads, blocks)` of `Int`s, where `threads` is the
# threadgroup size and `blocks` the number of threadgroups needed to cover
# `elements`. The `elements_per_thread` keyword is accepted but not used by this
# method.
#
# NOTE(review): with `elements == 0` the group size becomes 0 and `cld` would
# divide by zero — presumably callers never request an empty launch; confirm.
@inline function GPUArrays.launch_heuristic(::mtlArrayBackend, f::F, args::Vararg{Any,N};
                                            elements::Int,
                                            elements_per_thread::Int) where {F,N}
    # Build the kernel object without launching it so its pipeline can be queried.
    compiled = @metal launch=false f(mtlKernelContext(), args...)

    # The pipeline state reports the largest threadgroup this kernel supports;
    # never ask for more threads than there are elements.
    group_limit = compiled.pipeline.maxTotalThreadsPerThreadgroup
    group_size = min(elements, group_limit)

    # Round up so that a partially filled last group is still launched.
    group_count = cld(elements, group_size)

    return (; threads=Int(group_size), blocks=Int(group_count))
end

# Launch kernel `f` on the Metal backend.
#
# `f` is invoked with a fresh `mtlKernelContext()` as its first argument,
# followed by the splatted `args`. `threads` and `groups` are the launch sizes
# supplied by the caller; `name` is an optional kernel name (`nothing` allowed).
# The value of the `@metal` expression is returned.
function GPUArrays.gpu_call(::mtlArrayBackend, f, args, threads::Int, groups::Int;
name::Union{String,Nothing})
# Bare symbols are forwarded to `@metal` as launch options — presumably
# shorthand for `threads=threads groups=groups name=name`; confirm against the
# macro's argument handling before rewriting.
@metal threads groups name f(mtlKernelContext(), args...)
end


## on-device

# indexing
# Each GPUArrays index query forwards to the corresponding 1D Metal intrinsic.
# The context argument is only used for dispatch, so it is left unnamed.

# index of the current threadgroup within the grid
GPUArrays.blockidx(::mtlKernelContext) = threadgroup_position_in_grid_1d()
# number of threads in one threadgroup
GPUArrays.blockdim(::mtlKernelContext) = threads_per_threadgroup_1d()
# index of the current thread within its threadgroup
GPUArrays.threadidx(::mtlKernelContext) = thread_position_in_threadgroup_1d()
# number of threadgroups in the grid
GPUArrays.griddim(::mtlKernelContext) = threadgroups_per_grid_1d()
# index of the current thread within the whole grid
GPUArrays.global_index(::mtlKernelContext) = thread_position_in_grid_1d()
# total number of threads in the grid
GPUArrays.global_size(::mtlKernelContext) = threads_per_grid_1d()

# memory

# Provide a device array of element type `T` and shape `dims` backed by memory
# obtained from `emit_threadgroup_memory`.
#
# `dims` and `id` arrive as `Val` type parameters; `id` is not used by this
# method.
@inline function GPUArrays.LocalMemory(::mtlKernelContext, ::Type{T}, ::Val{dims},
                                       ::Val{id}) where {T, dims, id}
    # The backing allocation is flat: one slot per element of the requested shape.
    nelems = prod(dims)
    buffer = emit_threadgroup_memory(T, Val(nelems))

    # Wrap the pointer so it can be indexed with the original shape.
    return MtlDeviceArray(dims, buffer)
end

# synchronization

# Synchronization hook for kernels running with an `mtlKernelContext`: issues a
# `threadgroup_barrier` with the `MemoryFlagThreadGroup` flag and returns its
# result.
@inline function GPUArrays.synchronize_threads(::mtlKernelContext)
    return threadgroup_barrier(MemoryFlagThreadGroup)
end



#
# Host abstractions
#

# Map any `MtlArray` type (regardless of its parameters) to the Metal backend
# marker used for dispatch by the execution methods in this file.
function GPUArrays.backend(::Type{<:MtlArray})
    return mtlArrayBackend()
end

# Module-level table mapping an `MTLDevice` to a `GPUArrays.RNG`; starts empty.
# NOTE(review): presumably filled lazily by `GPUArrays.default_rng` so each
# device gets its own generator — the rest of that function is not visible here.
const GLOBAL_RNGs = Dict{MTLDevice,GPUArrays.RNG}()
function GPUArrays.default_rng(::Type{<:MtlArray})
dev = device()
Expand Down
3 changes: 1 addition & 2 deletions test/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,7 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES];
a = f(T, d)
Metal.seed!(1)
b = f(T, d)
# TODO: Remove broken parameter once https://github.com/JuliaGPU/GPUArrays.jl/issues/530 is fixed
@test Array(a) == Array(b) broken = (T == Float16 && d == (1000,1000))
@test Array(a) == Array(b)
end
end
end # testset

0 comments on commit 711758d

Please sign in to comment.