From c3705230bf7b58a38684155af0414743ef1bf5e2 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Sun, 12 Feb 2023 20:12:25 -0500 Subject: [PATCH] Transition GPUArrays to KernelAbstractions Co-authored-by: James Schloss --- .buildkite/pipeline.yml | 11 +- Project.toml | 1 + docs/src/index.md | 5 +- docs/src/interface.md | 57 +++---- lib/GPUArraysCore/src/GPUArraysCore.jl | 6 +- lib/JLArrays/Project.toml | 3 +- lib/JLArrays/src/JLArrays.jl | 198 ++++++++++------------- src/GPUArrays.jl | 8 +- src/device/execution.jl | 83 +--------- src/device/indexing.jl | 85 ---------- src/device/memory.jl | 27 ---- src/device/synchronization.jl | 13 -- src/host/abstractarray.jl | 32 ++-- src/host/base.jl | 28 ++-- src/host/broadcast.jl | 95 +++++------- src/host/construction.jl | 28 ++-- src/host/indexing.jl | 26 ++-- src/host/linalg.jl | 207 ++++++++++++------------- src/host/math.jl | 6 +- src/host/random.jl | 38 ++--- src/host/uniformscaling.jl | 28 ++-- test/Project.toml | 1 + test/runtests.jl | 17 ++ test/testsuite.jl | 2 +- test/testsuite/base.jl | 27 ++-- test/testsuite/broadcasting.jl | 3 +- test/testsuite/gpuinterface.jl | 47 ------ 27 files changed, 383 insertions(+), 699 deletions(-) delete mode 100644 src/device/indexing.jl delete mode 100644 src/device/memory.jl delete mode 100644 src/device/synchronization.jl delete mode 100644 test/testsuite/gpuinterface.jl diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index b7f4aa82..411ca23e 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -10,7 +10,7 @@ steps: println("--- :julia: Instantiating project") Pkg.develop(; path=pwd()) - Pkg.develop(; name="CUDA") + Pkg.add(; url="https://github.com/leios/CUDA.jl/", rev="GtK_trans") println("+++ :julia: Running tests") Pkg.test("CUDA"; coverage=true)' @@ -31,10 +31,13 @@ steps: println("--- :julia: Instantiating project") Pkg.develop(; path=pwd()) - Pkg.develop(; name="oneAPI") + Pkg.add(; url="https://github.com/leios/oneAPI.jl/", rev="GtK_transition") println("+++ :julia: Building support library") - include(joinpath(Pkg.devdir(), "oneAPI", "deps", "build_ci.jl")) + filename = Base.find_package("oneAPI") + filename = filename[1:findfirst("oneAPI.jl", filename)[1]-1] + filename *= "../deps/build_ci.jl" + include(filename) Pkg.activate() println("+++ :julia: Running tests") @@ -56,7 +59,7 @@ steps: println("--- :julia: Instantiating project") Pkg.develop(; path=pwd()) - Pkg.develop(; name="Metal") + Pkg.add(; url="https://github.com/leios/Metal.jl/", rev="GtK_transition") println("+++ :julia: Running tests") Pkg.test("Metal"; coverage=true)' diff --git a/Project.toml b/Project.toml index e3bbe181..e1e8d775 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "10.2.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" diff --git a/docs/src/index.md b/docs/src/index.md index cfdd3272..8cb100ca 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -9,10 +9,9 @@ will get a lot of functionality for free. This will allow to have multiple GPUAr implementation for different purposes, while maximizing the ability to share code. **This package is not intended for end users!** Instead, you should use one of the packages -that builds on GPUArrays.jl. There is currently only a single package that actively builds -on these interfaces, namely [CuArrays.jl](https://github.com/JuliaGPU/CuArrays.jl). +that builds on GPUArrays.jl such as [CUDA](https://github.com/JuliaGPU/CUDA.jl), [AMDGPU](https://github.com/JuliaGPU/AMDGPU.jl), [OneAPI](https://github.com/JuliaGPU/oneAPI.jl), or [Metal](https://github.com/JuliaGPU/Metal.jl). -In this documentation, you will find more information on the interface that you are expected +This documentation is meant for users who might wish to implement a version of GPUArrays for another GPU backend and will cover the features you will need to implement, the functionality you gain by doing so, and the test suite that is available to verify your implementation. GPUArrays.jl also provides a reference implementation of these interfaces on the CPU: The `JLArray` array type uses Julia's parallel programming diff --git a/docs/src/interface.md b/docs/src/interface.md index 01c0a3c9..239bef87 100644 --- a/docs/src/interface.md +++ b/docs/src/interface.md @@ -1,53 +1,32 @@ # Interface To extend the above functionality to a new array type, you should use the types and -implement the interfaces listed on this page. GPUArrays is design around having two -different array types to represent a GPU array: one that only ever lives on the host, and +implement the interfaces listed on this page. GPUArrays is designed around having two +different array types to represent a GPU array: one that exists only on the host, and one that actually can be instantiated on the device (i.e. in kernels). +Device functionality is then handled by [KernelAbstractions.jl](https://github.com/JuliaGPU/KernelAbstractions.jl). +## Host abstractions -## Device functionality - -Several types and interfaces are related to the device and execution of code on it. First of -all, you need to provide a type that represents your execution back-end and a way to call -kernels: +You should provide an array type that builds on the `AbstractGPUArray` supertype, such as: -```@docs -GPUArrays.AbstractGPUBackend -GPUArrays.AbstractKernelContext -GPUArrays.gpu_call -GPUArrays.thread_block_heuristic ``` +mutable struct CustomArray{T, N} <: AbstractGPUArray{T, N} + data::DataRef{Vector{UInt8}} + offset::Int + dims::Dims{N} + ... +end -You then need to provide implementations of certain methods that will be executed on the -device itself: - -```@docs -GPUArrays.AbstractDeviceArray -GPUArrays.LocalMemory -GPUArrays.synchronize_threads -GPUArrays.blockidx -GPUArrays.blockdim -GPUArrays.threadidx -GPUArrays.griddim ``` +This will allow your defined type (in this case `JLArray`) to use the GPUArrays interface where available. +To be able to actually use the functionality that is defined for `AbstractGPUArray`s, you need to define the backend, like so: -## Host abstractions - -You should provide an array type that builds on the `AbstractGPUArray` supertype: - -```@docs -AbstractGPUArray ``` - -First of all, you should implement operations that are expected to be defined for any -`AbstractArray` type. Refer to the Julia manual for more details, or look at the `JLArray` -reference implementation. - -To be able to actually use the functionality that is defined for `AbstractGPUArray`s, you -should provide implementations of the following interfaces: - -```@docs -GPUArrays.backend +import KernelAbstractions: Backend +struct CustomBackend <: KernelAbstractions.GPU +KernelAbstractions.get_backend(a::CA) where CA <: CustomArray = CustomBackend() ``` + +There are numerous examples of potential interfaces for GPUArrays, such as with [JLArrays](https://github.com/JuliaGPU/GPUArrays.jl/blob/master/lib/JLArrays/src/JLArrays.jl), [CuArrays](https://github.com/JuliaGPU/CUDA.jl/blob/master/src/gpuarrays.jl), and [ROCArrays](https://github.com/JuliaGPU/AMDGPU.jl/blob/master/src/gpuarrays.jl). diff --git a/lib/GPUArraysCore/src/GPUArraysCore.jl b/lib/GPUArraysCore/src/GPUArraysCore.jl index 0d0b5182..d3be4e2e 100644 --- a/lib/GPUArraysCore/src/GPUArraysCore.jl +++ b/lib/GPUArraysCore/src/GPUArraysCore.jl @@ -222,10 +222,10 @@ end Gets the GPUArrays back-end responsible for managing arrays of type `T`. """ -backend(::Type) = error("This object is not a GPU array") # COV_EXCL_LINE -backend(x) = backend(typeof(x)) +get_backend(::Type) = error("This object is not a GPU array") # COV_EXCL_LINE +get_backend(x) = get_backend(typeof(x)) # WrappedArray from Adapt for Base wrappers. -backend(::Type{WA}) where WA<:WrappedArray = backend(unwrap_type(WA)) +get_backend(::Type{WA}) where WA<:WrappedArray = backend(unwrap_type(WA)) end # module GPUArraysCore diff --git a/lib/JLArrays/Project.toml b/lib/JLArrays/Project.toml index ce8959b7..edae1492 100644 --- a/lib/JLArrays/Project.toml +++ b/lib/JLArrays/Project.toml @@ -6,10 +6,11 @@ version = "0.1.4" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] Adapt = "2.0, 3.0, 4.0" GPUArrays = "10" -julia = "1.8" Random = "1" +julia = "1.8" diff --git a/lib/JLArrays/src/JLArrays.jl b/lib/JLArrays/src/JLArrays.jl index 5098cc3f..c529b9ec 100644 --- a/lib/JLArrays/src/JLArrays.jl +++ b/lib/JLArrays/src/JLArrays.jl @@ -1,16 +1,17 @@ # reference implementation on the CPU - -# note that most of the code in this file serves to define a functional array type, -# the actual implementation of GPUArrays-interfaces is much more limited. +# This acts as a wrapper around KernelAbstractions's parallel CPU +# functionality. It is useful for testing GPUArrays (and other packages) +# when no GPU is present. +# This file follows conventions from AMDGPU.jl module JLArrays -export JLArray, JLVector, JLMatrix, jl - using GPUArrays - using Adapt +import KernelAbstractions +import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config +export JLArray, JLVector, JLMatrix, jl, JLBackend # # Device functionality @@ -18,37 +19,11 @@ using Adapt const MAXTHREADS = 256 - -## execution - -struct JLBackend <: AbstractGPUBackend end - -mutable struct JLKernelContext <: AbstractKernelContext - blockdim::Int - griddim::Int - blockidx::Int - threadidx::Int - - localmem_counter::Int - localmems::Vector{Vector{Array}} -end - -function JLKernelContext(threads::Int, blockdim::Int) - blockcount = prod(blockdim) - lmems = [Vector{Array}() for i in 1:blockcount] - JLKernelContext(threads, blockdim, 1, 1, 0, lmems) +struct JLBackend <: KernelAbstractions.GPU + static::Bool + JLBackend(;static::Bool=false) = new(static) end -function JLKernelContext(ctx::JLKernelContext, threadidx::Int) - JLKernelContext( - ctx.blockdim, - ctx.griddim, - ctx.blockidx, - threadidx, - 0, - ctx.localmems - ) -end struct Adaptor end jlconvert(arg) = adapt(Adaptor(), arg) @@ -60,28 +35,35 @@ end Base.getindex(r::JlRefValue) = r.x Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = JlRefValue(adapt(to, r[])) -function GPUArrays.gpu_call(::JLBackend, f, args, threads::Int, blocks::Int; - name::Union{String,Nothing}) - ctx = JLKernelContext(threads, blocks) - device_args = jlconvert.(args) - tasks = Array{Task}(undef, threads) - for blockidx in 1:blocks - ctx.blockidx = blockidx - for threadidx in 1:threads - thread_ctx = JLKernelContext(ctx, threadidx) - tasks[threadidx] = @async f(thread_ctx, device_args...) - # TODO: require 1.3 and use Base.Threads.@spawn for actual multithreading - # (this would require a different synchronization mechanism) - end - for t in tasks - fetch(t) - end +mutable struct JLArray{T, N} <: AbstractGPUArray{T, N} + data::DataRef{Vector{UInt8}} + + offset::Int # offset of the data in the buffer, in number of elements + + dims::Dims{N} + + # allocating constructor + function JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N} + check_eltype(T) + maxsize = prod(dims) * sizeof(T) + data = Vector{UInt8}(undef, maxsize) + ref = DataRef(data) + obj = new{T,N}(ref, 0, dims) + finalizer(unsafe_free!, obj) end - return -end + # low-level constructor for wrapping existing data + function JLArray{T,N}(ref::DataRef{Vector{UInt8}}, dims::Dims{N}; + offset::Int=0) where {T,N} + check_eltype(T) + obj = new{T,N}(ref, offset, dims) + finalizer(unsafe_free!, obj) + end +end -## executed on-device +Adapt.adapt_storage(::JLBackend, a::Array) = Adapt.adapt(JLArrays.JLArray, a) +Adapt.adapt_storage(::JLBackend, a::JLArrays.JLArray) = a +Adapt.adapt_storage(::KernelAbstractions.CPU, a::JLArrays.JLArray) = convert(Array, a) # array type @@ -107,43 +89,6 @@ end @inline Base.getindex(A::JLDeviceArray, index::Integer) = getindex(typed_data(A), index) @inline Base.setindex!(A::JLDeviceArray, x, index::Integer) = setindex!(typed_data(A), x, index) - -# indexing - -for f in (:blockidx, :blockdim, :threadidx, :griddim) - @eval GPUArrays.$f(ctx::JLKernelContext) = ctx.$f -end - -# memory - -function GPUArrays.LocalMemory(ctx::JLKernelContext, ::Type{T}, ::Val{dims}, ::Val{id}) where {T, dims, id} - ctx.localmem_counter += 1 - lmems = ctx.localmems[blockidx(ctx)] - - # first invocation in block - data = if length(lmems) < ctx.localmem_counter - lmem = fill(zero(T), dims) - push!(lmems, lmem) - lmem - else - lmems[ctx.localmem_counter] - end - - N = length(dims) - JLDeviceArray{T,N}(data, tuple(dims...)) -end - -# synchronization - -@inline function GPUArrays.synchronize_threads(::JLKernelContext) - # All threads are getting started asynchronously, so a yield will yield to the next - # execution of the same function, which should call yield at the exact same point in the - # program, leading to a chain of yields effectively syncing the tasks (threads). - yield() - return -end - - # # Host abstractions # @@ -157,32 +102,6 @@ function check_eltype(T) end end -mutable struct JLArray{T, N} <: AbstractGPUArray{T, N} - data::DataRef{Vector{UInt8}} - - offset::Int # offset of the data in the buffer, in number of elements - - dims::Dims{N} - - # allocating constructor - function JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N} - check_eltype(T) - maxsize = prod(dims) * sizeof(T) - data = Vector{UInt8}(undef, maxsize) - ref = DataRef(data) - obj = new{T,N}(ref, 0, dims) - finalizer(unsafe_free!, obj) - end - - # low-level constructor for wrapping existing data - function JLArray{T,N}(ref::DataRef{Vector{UInt8}}, dims::Dims{N}; - offset::Int=0) where {T,N} - check_eltype(T) - obj = new{T,N}(ref, offset, dims) - finalizer(unsafe_free!, obj) - end -end - unsafe_free!(a::JLArray) = GPUArrays.unsafe_free!(a.data) # conversion of untyped data to a typed Array @@ -400,8 +319,6 @@ end ## GPUArrays interfaces -GPUArrays.backend(::Type{<:JLArray}) = JLBackend() - Adapt.adapt_storage(::Adaptor, x::JLArray{T,N}) where {T,N} = JLDeviceArray{T,N}(x.data[], x.offset, x.dims) @@ -414,4 +331,47 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br R end +## KernelAbstractions interface + +KernelAbstractions.get_backend(a::JLA) where JLA <: JLArray = JLBackend() + +function KernelAbstractions.mkcontext(kernel::Kernel{JLBackend}, I, _ndrange, iterspace, ::Dynamic) where Dynamic + return KernelAbstractions.CompilerMetadata{KernelAbstractions.ndrange(kernel), Dynamic}(I, _ndrange, iterspace) +end + +KernelAbstractions.allocate(::JLBackend, ::Type{T}, dims::Tuple) where T = JLArray{T}(undef, dims) + +@inline function launch_config(kernel::Kernel{JLBackend}, ndrange, workgroupsize) + if ndrange isa Integer + ndrange = (ndrange,) + end + if workgroupsize isa Integer + workgroupsize = (workgroupsize, ) + end + + if KernelAbstractions.workgroupsize(kernel) <: DynamicSize && workgroupsize === nothing + workgroupsize = (1024,) # Vectorization, 4x unrolling, minimal grain size + end + iterspace, dynamic = partition(kernel, ndrange, workgroupsize) + # partition checked that the ndrange's agreed + if KernelAbstractions.ndrange(kernel) <: StaticSize + ndrange = nothing + end + + return ndrange, workgroupsize, iterspace, dynamic +end + +KernelAbstractions.isgpu(b::JLBackend) = false + +function convert_to_cpu(obj::Kernel{JLBackend, W, N, F}) where {W, N, F} + return Kernel{typeof(KernelAbstractions.CPU(; static = obj.backend.static)), W, N, F}(KernelAbstractions.CPU(; static = obj.backend.static), obj.f) +end + +function (obj::Kernel{JLBackend})(args...; ndrange=nothing, workgroupsize=nothing) + device_args = jlconvert.(args) + new_obj = convert_to_cpu(obj) + new_obj(device_args...; ndrange, workgroupsize) + +end + end diff --git a/src/GPUArrays.jl b/src/GPUArrays.jl index 2d4f1bd9..54e9c877 100644 --- a/src/GPUArrays.jl +++ b/src/GPUArrays.jl @@ -1,5 +1,6 @@ module GPUArrays +using KernelAbstractions using Serialization using Random using LinearAlgebra @@ -14,14 +15,11 @@ using LLVM.Interop using Reexport @reexport using GPUArraysCore -# device functionality -include("device/execution.jl") ## executed on-device +include("device/execution.jl") include("device/abstractarray.jl") -include("device/indexing.jl") -include("device/memory.jl") -include("device/synchronization.jl") +using KernelAbstractions # host abstractions include("host/abstractarray.jl") include("host/construction.jl") diff --git a/src/device/execution.jl b/src/device/execution.jl index 41285bc3..64a81dad 100644 --- a/src/device/execution.jl +++ b/src/device/execution.jl @@ -1,75 +1,5 @@ # kernel execution -export AbstractGPUBackend, AbstractKernelContext, gpu_call - -abstract type AbstractGPUBackend end - -abstract type AbstractKernelContext end - -import GPUArraysCore: backend - -""" - gpu_call(kernel::Function, arg0, args...; kwargs...) - -Executes `kernel` on the device that backs `arg` (see [`backend`](@ref)), passing along any -arguments `args`. Additionally, the kernel will be passed the kernel execution context (see -[`AbstractKernelContext`]), so its signature should be `(ctx::AbstractKernelContext, arg0, -args...)`. - -The keyword arguments `kwargs` are not passed to the function, but are interpreted on the -host to influence how the kernel is executed. The following keyword arguments are supported: - -- `target::AbstractArray`: specify which array object to use for determining execution - properties (defaults to the first argument `arg0`). -- `elements::Int`: how many elements will be processed by this kernel. In most - circumstances, this will correspond to the total number of threads that needs to be - launched, unless the kernel supports a variable number of elements to process per - iteration. Defaults to the length of `arg0` if no other keyword arguments that influence - the launch configuration are specified. -- `threads::Int` and `blocks::Int`: configure exactly how many threads and blocks are - launched. This cannot be used in combination with the `elements` argument. -- `name::String`: inform the back end about the name of the kernel to be executed. This can - be used to emit better diagnostics, and is useful with anonymous kernels. -""" -function gpu_call(kernel::F, args::Vararg{Any,N}; - target::AbstractArray=first(args), - elements::Union{Int,Nothing}=nothing, - threads::Union{Int,Nothing}=nothing, - blocks::Union{Int,Nothing}=nothing, - name::Union{String,Nothing}=nothing) where {F,N} - # non-trivial default values for launch configuration - if elements===nothing && threads===nothing && blocks===nothing - elements = length(target) - elseif elements===nothing - if threads === nothing - threads = 1 - end - if blocks === nothing - blocks = 1 - end - elseif threads!==nothing || blocks!==nothing - error("Cannot specify both elements and threads/blocks configuration") - end - - # the number of elements to process needs to be passed to the kernel somehow, so there's - # no easy way to do this without passing additional arguments or changing the context. - # both are expensive, so require manual use of `launch_heuristic` for those kernels. - elements_per_thread = 1 - - if elements !== nothing - @assert elements > 0 - heuristic = launch_heuristic(backend(target), kernel, args...; - elements, elements_per_thread) - config = launch_configuration(backend(target), heuristic; - elements, elements_per_thread) - gpu_call(backend(target), kernel, args, config.threads, config.blocks; name=name) - else - @assert threads > 0 - @assert blocks > 0 - gpu_call(backend(target), kernel, args, threads, blocks; name=name) - end -end - # how many threads and blocks `kernel` needs to be launched with, passing arguments `args`, # to fully saturate the GPU. `elements` indicates the number of elements that needs to be # processed, while `elements_per_threads` indicates the number of elements this kernel can @@ -77,16 +7,18 @@ end # # this heuristic should be specialized for the back-end, ideally using an API for maximizing # the occupancy of the launch configuration (like CUDA's occupancy API). -function launch_heuristic(backend::AbstractGPUBackend, kernel, args...; - elements::Int, elements_per_thread::Int) +function launch_heuristic(backend::B, kernel, args...; + elements::Int, + elements_per_thread::Int) where B <: Backend return (threads=256, blocks=32) end # determine how many threads and blocks to actually launch given upper limits. # returns a tuple of blocks, threads, and elements_per_thread (which is always 1 # unless specified that the kernel can handle a number of elements per thread) -function launch_configuration(backend::AbstractGPUBackend, heuristic; - elements::Int, elements_per_thread::Int) +function launch_configuration(backend::B, heuristic; + elements::Int, + elements_per_thread::Int) where B <: Backend threads = clamp(elements, 1, heuristic.threads) blocks = max(cld(elements, threads), 1) @@ -105,6 +37,3 @@ function launch_configuration(backend::AbstractGPUBackend, heuristic; (; threads, blocks, elements_per_thread=1) end end - -gpu_call(backend::AbstractGPUBackend, kernel, args, threads::Int, blocks::Int; kwargs...) = - error("Not implemented") # COV_EXCL_LINE diff --git a/src/device/indexing.jl b/src/device/indexing.jl deleted file mode 100644 index 31084fce..00000000 --- a/src/device/indexing.jl +++ /dev/null @@ -1,85 +0,0 @@ -# indexing - -export global_index, global_size, linear_index, @linearidx, @cartesianidx - - -## hardware - -for f in (:blockidx, :blockdim, :threadidx, :griddim) - @eval $f(ctx::AbstractKernelContext)::Int = error("Not implemented") # COV_EXCL_LINE - @eval export $f -end - -""" - global_index(ctx::AbstractKernelContext) - -Query the global index of the current thread in the launch configuration (i.e. as far as the -hardware is concerned). -""" -@inline function global_index(ctx::AbstractKernelContext) - threadidx(ctx) + (blockidx(ctx) - 1) * blockdim(ctx) -end - -""" - global_size(ctx::AbstractKernelContext) - -Query the global size of the launch configuration (total number of threads launched). -""" -@inline function global_size(ctx::AbstractKernelContext) - griddim(ctx) * blockdim(ctx) -end - - -## logical - -""" - linear_index(ctx::AbstractKernelContext, grididx::Int=1) - -Return a linear index for the current kernel by querying hardware registers (similar to -`get_global_id` in OpenCL). For applying a grid stride (in terms of [`global_size`](@ref)), -specify `grididx`. - -""" -@inline function linear_index(ctx::AbstractKernelContext, grididx::Int=1) - global_index(ctx) + (grididx - 1) * global_size(ctx) -end - -""" - linearidx(A, grididx=1, ctxsym=:ctx) - -Macro form of [`linear_index`](@ref), which return from the surrouunding scope when out of -bounds: - - ```julia - function kernel(ctx::AbstractKernelContext, A) - idx = @linearidx A - # from here on it's safe to index into A with idx - @inbounds begin - A[idx] = ... - end - end - ``` -""" -macro linearidx(A, grididx=1, ctxsym=:ctx) - quote - x = $(esc(A)) - i = linear_index($(esc(ctxsym)), $(esc(grididx))) - if !(1 <= i <= length(x)) - return - end - i - end -end - -""" - cartesianidx(A, grididx=1, ctxsym=:ctx) - -Like [`@linearidx`](@ref), but returns a N-dimensional `CartesianIndex`. -""" -macro cartesianidx(A, grididx=1, ctxsym=:ctx) - quote - x = $(esc(A)) - i = @linearidx(x, $(esc(grididx)), $(esc(ctxsym))) - @inbounds CartesianIndices(x)[i] - end -end diff --git a/src/device/memory.jl b/src/device/memory.jl deleted file mode 100644 index 901791d5..00000000 --- a/src/device/memory.jl +++ /dev/null @@ -1,27 +0,0 @@ -# on-device memory management - -export @LocalMemory - - -## thread-local array - -""" -Creates a local static memory shared inside one block. -Equivalent to `__local` of OpenCL or `__shared__ ()` of CUDA. -""" -macro LocalMemory(ctx, T, N) - id = gensym("local_memory") - quote - LocalMemory($(esc(ctx)), $(esc(T)), Val($(esc(N))), Val($(QuoteNode(id)))) - end -end - -""" -Creates a block local array pointer with `T` being the element type -and `N` the length. Both T and N need to be static! C is a counter for -approriately get the correct Local mem id in CUDAnative. -This is an internal method which needs to be overloaded by the GPU Array backends -""" -function LocalMemory(ctx, ::Type{T}, ::Val{dims}, ::Val{id}) where {T, dims, id} - error("Not implemented") # COV_EXCL_LINE -end diff --git a/src/device/synchronization.jl b/src/device/synchronization.jl deleted file mode 100644 index b16d2518..00000000 --- a/src/device/synchronization.jl +++ /dev/null @@ -1,13 +0,0 @@ -# synchronization - -export synchronize_threads - -""" - synchronize_threads(ctx::AbstractKernelContext) - -in CUDA terms `__synchronize` -in OpenCL terms: `barrier(CLK_LOCAL_MEM_FENCE)` -""" -function synchronize_threads(ctx::AbstractKernelContext) - error("Not implemented") # COV_EXCL_LINE -end diff --git a/src/host/abstractarray.jl b/src/host/abstractarray.jl index 291b2184..d5ce1fb5 100644 --- a/src/host/abstractarray.jl +++ b/src/host/abstractarray.jl @@ -173,13 +173,12 @@ for (D, S) in ((AnyGPUArray, Array), end # kernel-based variant for copying between wrapped GPU arrays - -function linear_copy_kernel!(ctx::AbstractKernelContext, dest, dstart, src, sstart, n) - i = linear_index(ctx)-1 - if i < n - @inbounds dest[dstart+i] = src[sstart+i] +# TODO: Add `@Const` to `src` +@kernel function linear_copy_kernel!(dest, dstart, src, sstart, n) + i = @index(Global, Linear) + if i <= n + @inbounds dest[dstart+i-1] = src[sstart+i-1] end - return end function Base.copyto!(dest::AnyGPUArray, dstart::Integer, @@ -189,10 +188,8 @@ function Base.copyto!(dest::AnyGPUArray, dstart::Integer, destinds, srcinds = LinearIndices(dest), LinearIndices(src) (checkbounds(Bool, destinds, dstart) && checkbounds(Bool, destinds, dstart+n-1)) || throw(BoundsError(dest, dstart:dstart+n-1)) (checkbounds(Bool, srcinds, sstart) && checkbounds(Bool, srcinds, sstart+n-1)) || throw(BoundsError(src, sstart:sstart+n-1)) - - gpu_call(linear_copy_kernel!, - dest, dstart, src, sstart, n; - elements=n) + kernel = linear_copy_kernel!(get_backend(dest)) + kernel(dest, dstart, src, sstart, n; ndrange=n) return dest end @@ -242,13 +239,9 @@ end ## generalized blocks of heterogeneous memory -function cartesian_copy_kernel!(ctx::AbstractKernelContext, dest, dest_offsets, src, src_offsets, shape, length) - i = linear_index(ctx) - if i <= length - idx = CartesianIndices(shape)[i] - @inbounds dest[idx + dest_offsets] = src[idx + src_offsets] - end - return +@kernel function cartesian_copy_kernel!(dest, dest_offsets, src, src_offsets) + I = @index(Global, Cartesian) + @inbounds dest[I + dest_offsets] = src[I + src_offsets] end function Base.copyto!(dest::AnyGPUArray{<:Any, N}, destcrange::CartesianIndices{N}, @@ -262,9 +255,8 @@ function Base.copyto!(dest::AnyGPUArray{<:Any, N}, destcrange::CartesianIndices{ dest_offsets = first(destcrange) - oneunit(CartesianIndex{N}) src_offsets = first(srccrange) - oneunit(CartesianIndex{N}) - gpu_call(cartesian_copy_kernel!, - dest, dest_offsets, src, src_offsets, shape, len; - elements=len) + kernel = cartesian_copy_kernel!(get_backend(dest)) + kernel(dest, dest_offsets, src, src_offsets; ndrange=shape) dest end diff --git a/src/host/base.jl b/src/host/base.jl index dc1c73a2..6c58d366 100644 --- a/src/host/base.jl +++ b/src/host/base.jl @@ -4,8 +4,7 @@ import Base: _RepeatInnerOuter # Handle `out = repeat(x; inner)` by parallelizing over `out` array This can benchmark # faster if repeating elements along the first axis (i.e. `inner=(n, ones...)`), as data # access can be contiguous on write. -function repeat_inner_dst_kernel!( - ctx::AbstractKernelContext, +@kernel function repeat_inner_dst_kernel!( xs::AbstractArray{<:Any, N}, inner::NTuple{N, Int}, out::AbstractArray{<:Any, N} @@ -13,27 +12,25 @@ function repeat_inner_dst_kernel!( # Get the "stride" index in each dimension, where the size # of the stride is given by `inner`. The stride-index (sdx) then # corresponds to the index of the repeated value in `xs`. - odx = @cartesianidx out + odx = @index(Global, Cartesian) dest_inds = odx.I sdx = ntuple(N) do i @inbounds (dest_inds[i] - 1) ÷ inner[i] + 1 end @inbounds out[odx] = xs[CartesianIndex(sdx)] - return nothing end # Handle `out = repeat(x; inner)` by parallelizing over the `xs` array This tends to # benchmark faster by having fewer read operations and avoiding the costly division # operation. Additionally, when repeating over the trailing dimension. `inner=(ones..., n)`, # data access can be contiguous during both the read and write operations. -function repeat_inner_src_kernel!( - ctx::AbstractKernelContext, +@kernel function repeat_inner_src_kernel!( xs::AbstractArray{<:Any, N}, inner::NTuple{N, Int}, out::AbstractArray{<:Any, N} ) where {N} # Get single element from src - idx = @cartesianidx xs + idx = @index(Global, Cartesian) @inbounds val = xs[idx] # Loop over "repeat" indices of inner @@ -44,7 +41,6 @@ function repeat_inner_src_kernel!( end @inbounds out[CartesianIndex(odx)] = val end - return nothing end function repeat_inner(xs::AnyGPUArray, inner) @@ -64,23 +60,24 @@ function repeat_inner(xs::AnyGPUArray, inner) # relevant benchmarks. if argmax(inner) == firstindex(inner) # Parallelize over the destination array - gpu_call(repeat_inner_dst_kernel!, xs, inner, out; elements=prod(size(out))) + kernel = repeat_inner_dst_kernel!(get_backend(out)) + kernel(xs, inner, out; ndrange=size(out)) else # Parallelize over the source array - gpu_call(repeat_inner_src_kernel!, xs, inner, out; elements=prod(size(xs))) + kernel = repeat_inner_src_kernel!(get_backend(xs)) + kernel(xs, inner, out; ndrange=size(xs)) end return out end -function repeat_outer_kernel!( - ctx::AbstractKernelContext, +@kernel function repeat_outer_kernel!( xs::AbstractArray{<:Any, N}, xssize::NTuple{N}, outer::NTuple{N}, out::AbstractArray{<:Any, N} ) where {N} # Get index to input element - idx = @cartesianidx xs + idx = @index(Global, Cartesian) @inbounds val = xs[idx] # Loop over repeat indices, copying val to out @@ -91,14 +88,13 @@ function repeat_outer_kernel!( end @inbounds out[CartesianIndex(odx)] = val end - - return nothing end function repeat_outer(xs::AnyGPUArray, outer) out = similar(xs, eltype(xs), outer .* size(xs)) any(==(0), size(out)) && return out # consistent with `Base.repeat` - gpu_call(repeat_outer_kernel!, xs, size(xs), outer, out; elements=length(xs)) + kernel = repeat_outer_kernel!(get_backend(xs)) + kernel(xs, size(xs), outer, out; ndrange=size(xs)) return out end diff --git a/src/host/broadcast.jl b/src/host/broadcast.jl index 79aa9759..f5f7847f 100644 --- a/src/host/broadcast.jl +++ b/src/host/broadcast.jl @@ -57,52 +57,39 @@ end typeof(BroadcastStyle(typeof(bc))) end - broadcast_kernel = if ndims(dest) == 1 || - (isa(IndexStyle(dest), IndexLinear) && - isa(IndexStyle(bc), IndexLinear)) - function (ctx, dest, nelem, bcstyle, bcf, bcaxes, bcargs...) - bc′ = @static if VERSION >= v"1.10-" - Broadcasted(bcstyle, bcf, bcargs, bcaxes) - else - Broadcasted{bcstyle}(bcf, bcargs, bcaxes) - end - - i = 1 - while i <= nelem - I = @linearidx(dest, i) - @inbounds dest[I] = bc′[I] - i += 1 - end - return + @kernel function broadcast_kernel_linear(dest, bcstyle, bcf, bcaxes, bcargs...) + bc′ = @static if VERSION >= v"1.10-" + Broadcasted(bcstyle, bcf, bcargs, bcaxes) + else + Broadcasted{bcstyle}(bcf, bcargs, bcaxes) end - else - function (ctx, dest, nelem, bcstyle, bcf, bcaxes, bcargs...) - bc′ = @static if VERSION >= v"1.10-" - Broadcasted(bcstyle, bcf, bcargs, bcaxes) - else - Broadcasted{bcstyle}(bcf, bcargs, bcaxes) - end - i = 0 - while i < nelem - i += 1 - I = @cartesianidx(dest, i) - @inbounds dest[I] = bc′[I] - end - return + I = @index(Global, Linear) + @inbounds dest[I] = bc′[I] + end + + @kernel function broadcast_kernel_cartesian(dest, bcstyle, bcf, bcaxes, bcargs...) + bc′ = @static if VERSION >= v"1.10-" + Broadcasted(bcstyle, bcf, bcargs, bcaxes) + else + Broadcasted{bcstyle}(bcf, bcargs, bcaxes) end + + I = @index(Global, Cartesian) + @inbounds dest[I] = bc′[I] end - elements = length(dest) - elements_per_thread = typemax(Int) - heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, 1, - bcstyle, bc.f, bc.axes, bc.args...; - elements, elements_per_thread) - config = launch_configuration(backend(dest), heuristic; - elements, elements_per_thread) - gpu_call(broadcast_kernel, dest, config.elements_per_thread::Int, - bcstyle, bc.f, bc.axes, bc.args...; - threads=config.threads, blocks=config.blocks) + # grid-stride kernel, ndrange set for possible 0D evaluation + if ndims(dest) == 1 || (isa(IndexStyle(dest), IndexLinear) && + isa(IndexStyle(bc), IndexLinear)) + broadcast_kernel_linear(get_backend(dest))(dest, bcstyle, + bc.f, bc.axes, bc.args...; + ndrange = length(size(dest)) > 0 ? length(dest) : 1) + else + broadcast_kernel_cartesian(get_backend(dest))(dest, bcstyle, + bc.f, bc.axes, bc.args...; + ndrange = sz = length(size(dest)) > 0 ? size(dest) : (1,)) + end if eltype(dest) <: BrokenBroadcast throw(ArgumentError("Broadcast operation resulting in $(eltype(eltype(dest))) is not GPU compatible")) @@ -152,27 +139,27 @@ function Base.map!(f, dest::AnyGPUArray, xs::AbstractArray...) end # grid-stride kernel - function map_kernel(ctx, dest, bc, nelem) - i = 1 - while i <= nelem - j = linear_index(ctx, i) - j > common_length && return + @kernel function map_kernel(dest, bc, nelem, common_length) - J = CartesianIndices(axes(bc))[j] - @inbounds dest[j] = bc[J] + j = 0 + J = @index(Global, Linear) + for i in 1:nelem + j += 1 + if j <= common_length - i += 1 + J_c = CartesianIndices(axes(bc))[(J-1)*nelem + j] + @inbounds dest[J_c] = bc[J_c] + end end - return end elements = common_length elements_per_thread = typemax(Int) - heuristic = launch_heuristic(backend(dest), map_kernel, dest, bc, 1; + heuristic = launch_heuristic(get_backend(dest), map_kernel, dest, bc, 1; elements, elements_per_thread) - config = launch_configuration(backend(dest), heuristic; + config = launch_configuration(get_backend(dest), heuristic; elements, elements_per_thread) - gpu_call(map_kernel, dest, bc, config.elements_per_thread; - threads=config.threads, blocks=config.blocks) + map_kernel(get_backend(dest))(dest, bc, config.elements_per_thread, + common_length; ndrange = config.threads) if eltype(dest) <: BrokenBroadcast throw(ArgumentError("Map operation resulting in $(eltype(eltype(dest))) is not GPU compatible")) diff --git a/src/host/construction.jl b/src/host/construction.jl index d80bce2d..18a3b6d7 100644 --- a/src/host/construction.jl +++ b/src/host/construction.jl @@ -11,29 +11,33 @@ Base.convert(::Type{T}, a::AbstractArray) where {T<:AbstractGPUArray} = a isa T function Base.fill!(A::AnyGPUArray{T}, x) where T isempty(A) && return A - gpu_call(A, convert(T, x)) do ctx, a, val - idx = @linearidx(a) + @kernel function fill_kernel!(a, val) + idx = @index(Global, Linear) @inbounds a[idx] = val - return end + + # ndrange set for a possible 0D evaluation + fill_kernel!(get_backend(A))(A, x, + ndrange = length(size(A)) > 0 ? size(A) : (1,)) A end ## identity matrices -function identity_kernel(ctx::AbstractKernelContext, res::AbstractArray{T}, stride, val) where T - i = linear_index(ctx) +@kernel function identity_kernel(res::AbstractArray{T}, stride, val) where T + i = @index(Global, Linear) ilin = (stride * (i - 1)) + i - ilin > length(res) && return - @inbounds res[ilin] = val - return + if ilin <= length(res) + @inbounds res[ilin] = val + end end function (T::Type{<: AnyGPUArray{U}})(s::UniformScaling, dims::Dims{2}) where {U} res = similar(T, dims) fill!(res, zero(U)) - gpu_call(identity_kernel, res, size(res, 1), s.λ; elements=minimum(dims)) + kernel = identity_kernel(get_backend(res)) + kernel(res, size(res, 1), s.λ; ndrange=minimum(dims)) res end @@ -43,7 +47,8 @@ end function Base.copyto!(A::AbstractGPUMatrix{T}, s::UniformScaling) where T fill!(A, zero(T)) - gpu_call(identity_kernel, A, size(A, 1), s.λ; elements=minimum(size(A))) + kernel = identity_kernel(get_backend(A)) + kernel(A, size(A, 1), s.λ; ndrange=minimum(size(A))) A end @@ -52,7 +57,8 @@ function _one(unit::T, x::AbstractGPUMatrix) where {T} m==n || throw(DimensionMismatch("multiplicative identity defined only for square matrices")) I = similar(x, T) fill!(I, zero(T)) - gpu_call(identity_kernel, I, m, unit; elements=m) + kernel = identity_kernel(get_backend(I)) + kernel(I, m, unit; ndrange=m) I end diff --git a/src/host/indexing.jl b/src/host/indexing.jl index 21c0dcdc..98b0780c 100644 --- a/src/host/indexing.jl +++ b/src/host/indexing.jl @@ -69,7 +69,7 @@ function vectorized_getindex!(dest::AbstractGPUArray, src::AbstractArray, Is...) Is = map(adapt(ToGPU(dest)), Is) @boundscheck checkbounds(src, Is...) - gpu_call(getindex_kernel, dest, src, idims, Is...) + getindex_kernel(get_backend(dest))(dest, src, idims, Is...; ndrange=size(dest)) return dest end @@ -79,15 +79,19 @@ function vectorized_getindex(src::AbstractGPUArray, Is...) return vectorized_getindex!(dest, src, Is...) end -@generated function getindex_kernel(ctx::AbstractKernelContext, dest, src, idims, - Is::Vararg{Any,N}) where {N} +@kernel function getindex_kernel(dest, src, idims, + Is::Vararg{Any,N}) where {N} + i = @index(Global, Linear) + getindex_generated(dest, src, idims, i, Is...) +end + +@generated function getindex_generated(dest, src, idims, i, + Is::Vararg{Any,N}) where {N} quote - i = @linearidx dest is = @inbounds CartesianIndices(idims)[i] @nexprs $N i -> I_i = @inbounds(Is[i][is[i]]) val = @ncall $N getindex src i -> I_i @inbounds dest[i] = val - return end end @@ -108,15 +112,19 @@ function vectorized_setindex!(dest::AbstractArray, src, Is...) Is = map(adapt(ToGPU(dest)), Is) @boundscheck checkbounds(dest, Is...) - gpu_call(setindex_kernel, dest, adapt(ToGPU(dest), src), idims, len, Is...; - elements=len) + setindex_kernel(get_backend(dest))(dest, adapt(ToGPU(dest), src), idims, len, Is...; + ndrange = length(dest)) return dest end -@generated function setindex_kernel(ctx::AbstractKernelContext, dest, src, idims, len, +@kernel function setindex_kernel(dest, src, idims, len, Is::Vararg{Any,N}) where {N} + i = @index(Global, Linear) + setindex_generated(dest, src, idims, len, i, Is...) +end +@generated function setindex_generated(dest, src, idims, len, i, + Is::Vararg{Any,N}) where {N} quote - i = linear_index(ctx) i > len && return is = @inbounds CartesianIndices(idims)[i] @nexprs $N i -> I_i = @inbounds(Is[i][is[i]]) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 5619dd83..2c7fed4e 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -36,20 +36,20 @@ function LinearAlgebra.transpose!(B::AbstractGPUMatrix, A::AbstractGPUVector) end function LinearAlgebra.adjoint!(B::AbstractGPUVector, A::AbstractGPUMatrix) axes(B,1) == axes(A,2) && axes(A,1) == 1:1 || throw(DimensionMismatch("adjoint")) - gpu_call(B, A) do ctx, B, A - idx = @linearidx B + @kernel function adjoint_kernel!(B, A) + idx = @index(Global, Linear) @inbounds B[idx] = adjoint(A[1, idx]) - return end + adjoint_kernel!(get_backend(B))(B, A, ndrange = size(B)) B end function LinearAlgebra.adjoint!(B::AbstractGPUMatrix, A::AbstractGPUVector) axes(B,2) == axes(A,1) && axes(B,1) == 1:1 || throw(DimensionMismatch("adjoint")) - gpu_call(B, A) do ctx, B, A - idx = @linearidx A + @kernel function adjoint_kernel!(B, A) + idx = @index(Global, Linear) @inbounds B[1, idx] = adjoint(A[idx]) - return end + adjoint_kernel!(get_backend(A))(B, A, ndrange = size(A)) B end @@ -57,11 +57,11 @@ LinearAlgebra.transpose!(B::AnyGPUArray, A::AnyGPUArray) = transpose_f!(transpos LinearAlgebra.adjoint!(B::AnyGPUArray, A::AnyGPUArray) = transpose_f!(adjoint, B, A) function transpose_f!(f, B::AnyGPUMatrix{T}, A::AnyGPUMatrix{T}) where T axes(B,1) == axes(A,2) && axes(B,2) == axes(A,1) || throw(DimensionMismatch(string(f))) - gpu_call(B, A) do ctx, B, A - idx = @cartesianidx A + @kernel function transpose_kernel!(B, A) + idx = @index(Global, Cartesian) @inbounds B[idx[2], idx[1]] = f(A[idx[1], idx[2]]) - return end + transpose_kernel!(get_backend(B))(B, A, ndrange = size(A)) B end @@ -82,48 +82,48 @@ end ## copy upper triangle to lower and vice versa -function LinearAlgebra.copytri!(A::AbstractGPUMatrix, uplo::AbstractChar, conjugate::Bool=false) - n = LinearAlgebra.checksquare(A) - if uplo == 'U' && conjugate - gpu_call(A) do ctx, _A - I = @cartesianidx _A - i, j = Tuple(I) - if j > i - @inbounds _A[j,i] = conj(_A[i,j]) +function LinearAlgebra.copytri!(A::AbstractGPUMatrix{T}, uplo::AbstractChar, conjugate::Bool=false) where T + n = LinearAlgebra.checksquare(A) + if uplo == 'U' && conjugate + @kernel function U_conj!(_A) + I = @index(Global, Cartesian) + i, j = Tuple(I) + if j > i + @inbounds _A[j,i] = conj(_A[i,j]) + end end - return - end - elseif uplo == 'U' && !conjugate - gpu_call(A) do ctx, _A - I = @cartesianidx _A - i, j = Tuple(I) - if j > i - @inbounds _A[j,i] = _A[i,j] + U_conj!(get_backend(A))(A, ndrange = size(A)) + elseif uplo == 'U' && !conjugate + @kernel function U_noconj!(_A) + I = @index(Global, Cartesian) + i, j = Tuple(I) + if j > i + @inbounds _A[j,i] = _A[i,j] + end end - return - end - elseif uplo == 'L' && conjugate - gpu_call(A) do ctx, _A - I = @cartesianidx _A - i, j = Tuple(I) - if j > i - @inbounds _A[i,j] = conj(_A[j,i]) + U_noconj!(get_backend(A))(A, ndrange = size(A)) + elseif uplo == 'L' && conjugate + @kernel function L_conj!(_A) + I = @index(Global, Cartesian) + i, j = Tuple(I) + if j > i + @inbounds _A[i,j] = conj(_A[j,i]) + end end - return - end - elseif uplo == 'L' && !conjugate - gpu_call(A) do ctx, _A - I = @cartesianidx _A - i, j = Tuple(I) - if j > i - @inbounds _A[i,j] = _A[j,i] + L_conj!(get_backend(A))(A, ndrange = size(A)) + elseif uplo == 'L' && !conjugate + @kernel function L_noconj!(_A) + I = @index(Global, Cartesian) + i, j = Tuple(I) + if j > i + @inbounds _A[i,j] = _A[j,i] + end end - return - end - else - throw(ArgumentError("uplo argument must be 'U' (upper) or 'L' (lower), got $uplo")) - end - A + L_noconj!(get_backend(A))(A, ndrange = size(A)) + else + throw(ArgumentError("uplo argument must be 'U' (upper) or 'L' (lower), got $uplo")) + end + A end ## copy a triangular part of a matrix to another matrix @@ -135,23 +135,23 @@ if isdefined(LinearAlgebra, :copytrito!) m1,n1 = size(B) (m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least the same number of rows and columns than A of size ($m,$n)")) if uplo == 'U' - gpu_call(A, B) do ctx, _A, _B - I = @cartesianidx _A + @kernel function U_kernel!(_A, _B) + I = @index(Global, Cartesian) i, j = Tuple(I) if j >= i @inbounds _B[i,j] = _A[i,j] end - return end + U_kernel!(get_backend(B))(A, B, ndrange = size(A)) else # uplo == 'L' - gpu_call(A, B) do ctx, _A, _B - I = @cartesianidx _A + @kernel function L_kernel!(_A, _B) + I = @index(Global, Cartesian) i, j = Tuple(I) if j <= i @inbounds _B[i,j] = _A[i,j] end - return end + L_kernel!(get_backend(A))(A, B, ndrange = size(A)) end return B end @@ -171,26 +171,26 @@ for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriang end function LinearAlgebra.tril!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T - gpu_call(A, d; name="tril!") do ctx, _A, _d - I = @cartesianidx _A + @kernel function tril_kernel!(_A, _d) + I = @index(Global, Cartesian) i, j = Tuple(I) if i < j - _d @inbounds _A[i, j] = zero(T) end - return end + tril_kernel!(get_backend(A))(A, d, ndrange = size(A)) return A end function LinearAlgebra.triu!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T - gpu_call(A, d; name="triu!") do ctx, _A, _d - I = @cartesianidx _A + @kernel function triu_kernel!(_A, _d) + I = @index(Global, Cartesian) i, j = Tuple(I) if j < i + _d @inbounds _A[i, j] = zero(T) end - return end + triu_kernel!(get_backend(A))(A, d, ndrange = size(A)) return A end @@ -352,9 +352,9 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac return fill!(C, zero(R)) end - gpu_call(C, A, B; name="matmatmul!") do ctx, C, A, B - idx = @linearidx C + @kernel function matmatmul_kernel!(C, A, B) assume.(size(C) .> 0) + idx = @index(Global, Linear) i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 @inbounds if i <= size(A,1) && j <= size(B,2) @@ -365,10 +365,8 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac end C[i,j] = add(Cij, C[i,j]) end - - return end - + matmatmul_kernel!(get_backend(C))(C, A, B, ndrange = size(C)) C end @@ -404,8 +402,8 @@ function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Fun upper = tfun === identity ? uploc == 'U' : uploc != 'U' unit = isunitc == 'U' - function trimatmul(ctx, C, A, B) - idx = @linearidx C + @kernel function trimatmul(C, A, B) + idx = @index(Global, Linear) assume.(size(C) .> 0) i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 l, m, n = size(A, 1), size(B, 1), size(B, 2) @@ -419,12 +417,10 @@ function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Fun end C[i,j] += Cij end - - return end - function trimatmul_t(ctx, C, A, B) - idx = @linearidx C + @kernel function trimatmul_t(C, A, B) + idx = @index(Global, Linear) assume.(size(C) .> 0) i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 l, m, n = size(A, 1), size(B, 1), size(B, 2) @@ -438,12 +434,10 @@ function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Fun end C[i,j] += Cij end - - return end - function trimatmul_a(ctx, C, A, B) - idx = @linearidx C + @kernel function trimatmul_a(C, A, B) + idx = @index(Global, Linear) assume.(size(C) .> 0) i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 l, m, n = size(A, 1), size(B, 1), size(B, 2) @@ -457,16 +451,14 @@ function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Fun end C[i,j] += Cij end - - return end if tfun === identity - gpu_call(trimatmul, C, A, B; name="trimatmul") + trimatmul(get_backend(C))(C, A, B, ndrange = length(C)) elseif tfun == transpose - gpu_call(trimatmul_t, C, A, B; name="trimatmul_t") + trimatmul_t(get_backend(C))(C, A, B, ndrange = length(C)) elseif tfun === adjoint - gpu_call(trimatmul_a, C, A, B; name="trimatmul_a") + trimatmul_a(get_backend(C))(C, A, B, ndrange = length(C)) else error("Not supported") end @@ -488,8 +480,8 @@ function generic_mattrimul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Fun upper = tfun === identity ? uploc == 'U' : uploc != 'U' unit = isunitc == 'U' - function mattrimul(ctx, C, A, B) - idx = @linearidx C + @kernel function mattrimul(C, A, B) + idx = @index(Global, Linear) assume.(size(C) .> 0) i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 l, m, n = size(A, 1), size(B, 1), size(B, 2) @@ -503,12 +495,10 @@ function generic_mattrimul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Fun end C[i,j] += Cij end - - return end - function mattrimul_t(ctx, C, A, B) - idx = @linearidx C + @kernel function mattrimul_t(C, A, B) + idx = @index(Global, Linear) assume.(size(C) .> 0) i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 l, m, n = size(A, 1), size(B, 1), size(B, 2) @@ -522,12 +512,10 @@ function generic_mattrimul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Fun end C[i,j] += Cij end - - return end - function mattrimul_a(ctx, C, A, B) - idx = @linearidx C + @kernel function mattrimul_a(C, A, B) + idx = @index(Global, Linear) assume.(size(C) .> 0) i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 l, m, n = size(A, 1), size(B, 1), size(B, 2) @@ -541,16 +529,14 @@ function generic_mattrimul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Fun end C[i,j] += Cij end - - return end if tfun === identity - gpu_call(mattrimul, C, A, B; name="mattrimul") + mattrimul(get_backend(C))(C, A, B, ndrange = length(C)) elseif tfun == transpose - gpu_call(mattrimul_t, C, A, B; name="mattrimul_t") + mattrimul_t(get_backend(C))(C, A, B, ndrange = length(C)) elseif tfun === adjoint - gpu_call(mattrimul_a, C, A, B; name="mattrimul_a") + mattrimul_a(get_backend(C))(C, A, B, ndrange = length(C)) else error("Not supported") end @@ -600,22 +586,22 @@ end end # VERSION function generic_rmul!(X::AbstractArray, s::Number) - gpu_call(X, s; name="rmul!") do ctx, X, s - i = @linearidx X + @kernel function rmul_kernel!(X, s) + i = @index(Global, Linear) @inbounds X[i] *= s - return end + rmul_kernel!(get_backend(X))(X, s, ndrange = size(X)) return X end LinearAlgebra.rmul!(A::AbstractGPUArray, b::Number) = generic_rmul!(A, b) function generic_lmul!(s::Number, X::AbstractArray) - gpu_call(X, s; name="lmul!") do ctx, X, s - i = @linearidx X + @kernel function lmul_kernel!(X, s) + i = @index(Global, Linear) @inbounds X[i] = s*X[i] - return end + lmul_kernel!(get_backend(X))(X, s, ndrange = size(X)) return X end @@ -657,15 +643,16 @@ function _permutedims!(::Type{IT}, dest::AbstractGPUArray, dest_strides = ntuple(k->k==1 ? 1 : prod(i->size(dest, i), 1:k-1), N) dest_strides_perm = ntuple(i->IT(dest_strides[findfirst(==(i), perm)]), N) size_src = IT.(size(src)) - function permutedims_kernel(ctx, dest, src, size_src, dest_strides_perm) - SLI = @linearidx dest + @kernel function permutedims_kernel!(dest, src, size_src, dest_strides_perm) + SLI = @index(Global, Linear) assume(0 < SLI <= typemax(IT)) LI = IT(SLI) dest_index = permute_linearindex(size_src, LI, dest_strides_perm) @inbounds dest[dest_index] = src[LI] - return end - gpu_call(permutedims_kernel, vec(dest), vec(src), size_src, dest_strides_perm) + permutedims_kernel!(get_backend(dest))(vec(dest), vec(src), size_src, + dest_strides_perm, + ndrange = size(dest)) return dest end @@ -742,28 +729,28 @@ end ## rotate function LinearAlgebra.rotate!(x::AbstractGPUArray, y::AbstractGPUArray, c::Number, s::Number) - gpu_call(x, y, c, s; name="rotate!") do ctx, x, y, c, s - i = @linearidx x + @kernel function rotate_kernel!(x, y, c, s) + i = @index(Global, Linear) @inbounds xi = x[i] @inbounds yi = y[i] @inbounds x[i] = c * xi + s * yi @inbounds y[i] = -conj(s) * xi + c * yi - return end + rotate_kernel!(get_backend(x))(x, y, c, s, ndrange = size(x)) return x, y end ## reflect function LinearAlgebra.reflect!(x::AbstractGPUArray, y::AbstractGPUArray, c::Number, s::Number) - gpu_call(x, y, c, s; name="reflect!") do ctx, x, y, c, s - i = @linearidx x + @kernel function reflect_kernel!(x, y, c, s) + i = @index(Global, Linear) @inbounds xi = x[i] @inbounds yi = y[i] @inbounds x[i] = c * xi + s * yi @inbounds y[i] = conj(s) * xi - c * yi - return end + reflect_kernel!(get_backend(x))(x, y, c, s, ndrange = size(x)) return x, y end diff --git a/src/host/math.jl b/src/host/math.jl index cf455d31..f96fb8ed 100644 --- a/src/host/math.jl +++ b/src/host/math.jl @@ -1,10 +1,10 @@ # Base mathematical operations function Base.clamp!(A::AnyGPUArray, low, high) - gpu_call(A, low, high) do ctx, A, low, high - I = @linearidx A + @kernel function clamp_kernel!(A, low, high) + I = @index(Global, Cartesian) A[I] = clamp(A[I], low, high) - return end + clamp_kernel!(get_backend(A))(A, low, high, ndrange = size(A)) return A end diff --git a/src/host/random.jl b/src/host/random.jl index b7e5dc74..2112a8ed 100644 --- a/src/host/random.jl +++ b/src/host/random.jl @@ -30,15 +30,13 @@ function next_rand(state::NTuple{4, T}) where {T <: Unsigned} return state, tmp end -function gpu_rand(::Type{T}, ctx::AbstractKernelContext, randstate::AbstractVector{NTuple{4, UInt32}}) where T - threadid = GPUArrays.threadidx(ctx) +function gpu_rand(::Type{T}, threadid, randstate::AbstractVector{NTuple{4, UInt32}}) where T stateful_rand = next_rand(randstate[threadid]) randstate[threadid] = stateful_rand[1] return make_rand_num(T, stateful_rand[2]) end -function gpu_rand(::Type{T}, ctx::AbstractKernelContext, randstate::AbstractVector{NTuple{4, UInt32}}) where T <: Integer - threadid = GPUArrays.threadidx(ctx) +function gpu_rand(::Type{T}, threadid, randstate::AbstractVector{NTuple{4, UInt32}}) where T <: Integer result = zero(T) if sizeof(T) >= 4 for _ in 1:sizeof(T) >> 2 @@ -55,9 +53,9 @@ end # support for complex numbers -function gpu_rand(::Type{Complex{T}}, ctx::AbstractKernelContext, randstate::AbstractVector{NTuple{4, UInt32}}) where T - re = gpu_rand(T, ctx, randstate) - im = gpu_rand(T, ctx, randstate) +function gpu_rand(::Type{Complex{T}}, threadid, randstate::AbstractVector{NTuple{4, UInt32}}) where T + re = gpu_rand(T, threadid, randstate) + im = gpu_rand(T, threadid, randstate) return complex(re, im) end @@ -85,29 +83,31 @@ end function Random.rand!(rng::RNG, A::AnyGPUArray{T}) where T <: Number isempty(A) && return A - gpu_call(A, rng.state) do ctx, a, randstates - idx = linear_index(ctx) - idx > length(a) && return - @inbounds a[idx] = gpu_rand(T, ctx, randstates) - return + @kernel function rand!(a, randstate) + idx = @index(Global, Linear) + @inbounds a[idx] = gpu_rand(T, ((idx-1)%length(randstate)+1), randstate) end + rand!(get_backend(A))(A, rng.state, ndrange = size(A)) A end function Random.randn!(rng::RNG, A::AnyGPUArray{T}) where T <: Number isempty(A) && return A threads = (length(A) - 1) ÷ 2 + 1 - gpu_call(A, rng.state; elements = threads) do ctx, a, randstates - idx = 2*(linear_index(ctx) - 1) + 1 - U1 = gpu_rand(T, ctx, randstates) - U2 = gpu_rand(T, ctx, randstates) + @kernel function randn!(a, randstates) + i = @index(Global, Linear) + idx = 2*(i - 1) + 1 + U1 = gpu_rand(T, i, randstates) + U2 = gpu_rand(T, i, randstates) Z0 = sqrt(T(-2.0)*log(U1))*cos(T(2pi)*U2) Z1 = sqrt(T(-2.0)*log(U1))*sin(T(2pi)*U2) @inbounds a[idx] = Z0 - idx + 1 > length(a) && return - @inbounds a[idx + 1] = Z1 - return + if idx + 1 <= length(a) + @inbounds a[idx + 1] = Z1 + end end + kernel = randn!(get_backend(A)) + kernel(A, rng.state; ndrange=threads) A end diff --git a/src/host/uniformscaling.jl b/src/host/uniformscaling.jl index 848eef5e..f8f8ae5a 100644 --- a/src/host/uniformscaling.jl +++ b/src/host/uniformscaling.jl @@ -12,20 +12,16 @@ const unittriangularwrappers = ( (:UnitLowerTriangular, :LowerTriangular) ) -function kernel_generic(ctx, B, J, min_size) - lin_idx = linear_index(ctx) - lin_idx > min_size && return nothing +@kernel function kernel_generic(B, J) + lin_idx = @index(Global, Linear) @inbounds diag_idx = diagind(B)[lin_idx] @inbounds B[diag_idx] += J - return nothing end -function kernel_unittriangular(ctx, B, J, diagonal_val, min_size) - lin_idx = linear_index(ctx) - lin_idx > min_size && return nothing +@kernel function kernel_unittriangular(B, J, diagonal_val) + lin_idx = @index(Global, Linear) @inbounds diag_idx = diagind(B)[lin_idx] @inbounds B[diag_idx] = diagonal_val + J - return nothing end for (t1, t2) in unittriangularwrappers @@ -34,7 +30,7 @@ for (t1, t2) in unittriangularwrappers B = similar(parent(A), typeof(oneunit(T) + J)) copyto!(B, parent(A)) min_size = minimum(size(B)) - gpu_call(kernel_unittriangular, B, J, one(eltype(B)), min_size; elements=min_size) + kernel_unittriangular(get_backend(B))(B, J, one(eltype(B)); ndrange=min_size) return $t2(B) end @@ -42,7 +38,7 @@ for (t1, t2) in unittriangularwrappers B = similar(parent(A), typeof(J - oneunit(T))) B .= .- parent(A) min_size = minimum(size(B)) - gpu_call(kernel_unittriangular, B, J, -one(eltype(B)), min_size; elements=min_size) + kernel_unittriangular(get_backend(B))(B, J, -one(eltype(B)); ndrange=min_size) return $t2(B) end end @@ -54,7 +50,7 @@ for t in genericwrappers B = similar(parent(A), typeof(oneunit(T) + J)) copyto!(B, parent(A)) min_size = minimum(size(B)) - gpu_call(kernel_generic, B, J, min_size; elements=min_size) + kernel_generic(get_backend(B))(B, J; ndrange=min_size) return $t(B) end @@ -62,7 +58,7 @@ for t in genericwrappers B = similar(parent(A), typeof(J - oneunit(T))) B .= .- parent(A) min_size = minimum(size(B)) - gpu_call(kernel_generic, B, J, min_size; elements=min_size) + kernel_generic(get_backend(B))(B, J; ndrange=min_size) return $t(B) end end @@ -73,7 +69,7 @@ function (+)(A::Hermitian{T,<:AbstractGPUMatrix}, J::UniformScaling{<:Complex}) B = similar(parent(A), typeof(oneunit(T) + J)) copyto!(B, parent(A)) min_size = minimum(size(B)) - gpu_call(kernel_generic, B, J, min_size; elements=min_size) + kernel_generic(get_backend(B))(B, J; ndrange=min_size) return B end @@ -81,7 +77,7 @@ function (-)(J::UniformScaling{<:Complex}, A::Hermitian{T,<:AbstractGPUMatrix}) B = similar(parent(A), typeof(J - oneunit(T))) B .= .-parent(A) min_size = minimum(size(B)) - gpu_call(kernel_generic, B, J, min_size; elements=min_size) + kernel_generic(get_backend(B))(B, J; ndrange=min_size) return B end @@ -90,7 +86,7 @@ function (+)(A::AbstractGPUMatrix{T}, J::UniformScaling) where T B = similar(A, typeof(oneunit(T) + J)) copyto!(B, A) min_size = minimum(size(B)) - gpu_call(kernel_generic, B, J, min_size; elements=min_size) + kernel_generic(get_backend(B))(B, J; ndrange=min_size) return B end @@ -98,6 +94,6 @@ function (-)(J::UniformScaling, A::AbstractGPUMatrix{T}) where T B = similar(A, typeof(J - oneunit(T))) B .= .-A min_size = minimum(size(B)) - gpu_call(kernel_generic, B, J, min_size; elements=min_size) + kernel_generic(get_backend(B))(B, J; ndrange=min_size) return B end diff --git a/test/Project.toml b/test/Project.toml index 76e1e22a..eb59ac76 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,7 @@ [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/test/runtests.jl b/test/runtests.jl index 4df72b2b..2cfdb0c6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,22 @@ using GPUArrays, Test, Pkg include("testsuite.jl") +@testset "JLArray" begin + # install the JLArrays subpackage in a temporary environment + old_project = Base.active_project() + Pkg.activate(; temp=true) + Pkg.develop(path=joinpath(dirname(@__DIR__), "lib", "JLArrays")) + + using JLArrays + + jl([1]) + + TestSuite.test(JLArray) + + Pkg.activate(old_project) +end + +#= @testset "JLArray" begin using JLArrays @@ -9,6 +25,7 @@ include("testsuite.jl") TestSuite.test(JLArray) end +=# @testset "Array" begin TestSuite.test(Array) diff --git a/test/testsuite.jl b/test/testsuite.jl index e7c14646..b939d2e9 100644 --- a/test/testsuite.jl +++ b/test/testsuite.jl @@ -8,6 +8,7 @@ export supported_eltypes using GPUArrays +using KernelAbstractions using LinearAlgebra using Random using Test @@ -85,7 +86,6 @@ macro testsuite(name, ex) end include("testsuite/construction.jl") -include("testsuite/gpuinterface.jl") include("testsuite/indexing.jl") include("testsuite/base.jl") include("testsuite/vector.jl") diff --git a/test/testsuite/base.jl b/test/testsuite/base.jl index 2fd0f417..5a58d3d2 100644 --- a/test/testsuite/base.jl +++ b/test/testsuite/base.jl @@ -1,28 +1,23 @@ -function cartesian_iter(state, res, A, Asize) - for i in CartesianIndices(Asize) - res[i] = A[i] - end - return +@kernel function cartesian_iter(res, A) + i = @index(Global, Cartesian) + res[i] = A[i] end -function clmap!(ctx, f, out, b) - i = linear_index(ctx) # get the kernel index it gets scheduled on +@kernel function clmap!(f, out, b) + i = @index(Global, Linear) # get the kernel index it gets scheduled on out[i] = f(b[i]) - return end -function ntuple_test(ctx, result, ::Val{N}) where N +@kernel function ntuple_test(result, ::Val{N}) where N result[1] = ntuple(Val(N)) do i Float32(i) * 77f0 end - return end -function ntuple_closure(ctx, result, ::Val{N}, testval) where N +@kernel function ntuple_closure(result, ::Val{N}, testval) where N result[1] = ntuple(Val(N)) do i Float32(i) * testval end - return end @testsuite "base" (AT, eltypes)->begin @@ -174,10 +169,10 @@ end AT <: AbstractGPUArray && @testset "ntuple test" begin result = AT(Vector{NTuple{3, Float32}}(undef, 1)) - gpu_call(ntuple_test, result, Val(3)) + ntuple_test(get_backend(result))(result, Val(3); ndrange = 1) @test Array(result)[1] == (77, 2*77, 3*77) x = 88f0 - gpu_call(ntuple_closure, result, Val(3), x) + ntuple_closure(get_backend(result))(result, Val(3), x; ndrange = 1) @test Array(result)[1] == (x, 2*x, 3*x) end @@ -185,14 +180,14 @@ end Ac = rand(Float32, 32, 32) A = AT(Ac) result = fill!(copy(A), 0.0f0) - gpu_call(cartesian_iter, result, A, size(A)) + cartesian_iter(get_backend(A))(result, A; ndrange = size(A)) Array(result) == Ac end AT <: AbstractGPUArray && @testset "Custom kernel from Julia function" begin x = AT(rand(Float32, 100)) y = AT(rand(Float32, 100)) - gpu_call(clmap!, -, x, y; target=x) + clmap!(get_backend(x))(-, x, y; ndrange = size(x)) jy = Array(y) @test map!(-, jy, jy) ≈ Array(x) end diff --git a/test/testsuite/broadcasting.jl b/test/testsuite/broadcasting.jl index b856eb0f..81b028f3 100644 --- a/test/testsuite/broadcasting.jl +++ b/test/testsuite/broadcasting.jl @@ -200,8 +200,9 @@ Base.size(A::WrapArray) = size(A.data) # For kernal support Adapt.adapt_structure(to, s::WrapArray) = WrapArray(Adapt.adapt(to, s.data)) # For broadcast support -GPUArrays.backend(::Type{WrapArray{T,N,P}}) where {T,N,P} = GPUArrays.backend(P) Broadcast.BroadcastStyle(::Type{WrapArray{T,N,P}}) where {T,N,P} = Broadcast.BroadcastStyle(P) +KernelAbstractions.get_backend(a::WA) where WA <: WrapArray = get_backend(a.data) + function unknown_wrapper(AT, eltypes) for ET in eltypes diff --git a/test/testsuite/gpuinterface.jl b/test/testsuite/gpuinterface.jl deleted file mode 100644 index 1455c732..00000000 --- a/test/testsuite/gpuinterface.jl +++ /dev/null @@ -1,47 +0,0 @@ -@testsuite "interface" (AT, eltypes)->begin - AT <: AbstractGPUArray || return - - N = 10 - x = AT(Vector{Int}(undef, N)) - x .= 0 - gpu_call(x) do ctx, x - x[linear_index(ctx)] = 2 - return - end - @test all(x-> x == 2, Array(x)) - - gpu_call(x; elements=N) do ctx, x - x[linear_index(ctx)] = 2 - return - end - @test all(x-> x == 2, Array(x)) - gpu_call(x; threads=2, blocks=(N ÷ 2)) do ctx, x - x[linear_index(ctx)] = threadidx(ctx) - return - end - @test Array(x) == [1,2,1,2,1,2,1,2,1,2] - - gpu_call(x; threads=2, blocks=(N ÷ 2)) do ctx, x - x[linear_index(ctx)] = blockidx(ctx) - return - end - @test Array(x) == [1, 1, 2, 2, 3, 3, 4, 4, 5, 5] - x2 = AT([0]) - gpu_call(x2; threads=2, blocks=(N ÷ 2), target=x) do ctx, x - x[1] = blockdim(ctx) - return - end - @test Array(x2) == [2] - - gpu_call(x2; threads=2, blocks=(N ÷ 2), target=x) do ctx, x - x[1] = griddim(ctx) - return - end - @test Array(x2) == [5] - - gpu_call(x2; threads=2, blocks=(N ÷ 2), target=x) do ctx, x - x[1] = global_size(ctx) - return - end - @test Array(x2) == [10] -end