diff --git a/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py b/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py index 86506fdb42..93f87ca3ca 100644 --- a/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py +++ b/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py @@ -54,7 +54,7 @@ def _intrinsic_spirv_global_index_const( sig = types.int64(types.int32) def _intrinsic_spirv_global_index_const_gen( - context: SPIRVTargetContext, + context: SPIRVTargetContext, # pylint: disable=unused-argument builder: llvmir.IRBuilder, sig, # pylint: disable=unused-argument args, @@ -79,7 +79,16 @@ def _intrinsic_spirv_global_index_const_gen( dim, ) - return context.cast(builder, res, types.uintp, types.intp) + # Generating same check as sycl does. Did they add it to avoid pointer + # bitcast on special constant? + max_int32 = llvmir.Constant(res.type, 2147483648) + cmp = builder.icmp_unsigned("<", res, max_int32) + + inst = builder.assume(cmp) + # TODO: tail does not always work + inst.tail = "tail" + + return res return sig, _intrinsic_spirv_global_index_const_gen diff --git a/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_private_array_overloads.py b/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_private_array_overloads.py index 66f7b7e835..958387c8df 100644 --- a/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_private_array_overloads.py +++ b/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_private_array_overloads.py @@ -9,10 +9,11 @@ import llvmlite.ir as llvmir from llvmlite.ir.builder import IRBuilder +from numba.core import cgutils, types from numba.core.typing.npydecl import parse_dtype as _ty_parse_dtype from numba.core.typing.npydecl import parse_shape as _ty_parse_shape from numba.core.typing.templates import Signature -from numba.extending import intrinsic, overload +from numba.extending import type_callable from numba_dpex.core.types import USMNdArray from numba_dpex.experimental.target import DpexExpKernelTypingContext @@ -23,55 +24,12 @@ ) from numba_dpex.utils import address_space as AddressSpace -from ..target import DPEX_KERNEL_EXP_TARGET_NAME +from ._registry import lower -@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME) -def _intrinsic_private_array_ctor( - ty_context, ty_shape, ty_dtype # pylint: disable=unused-argument -): - require_literal(ty_shape) - - ty_array = USMNdArray( - dtype=_ty_parse_dtype(ty_dtype), - ndim=_ty_parse_shape(ty_shape), - layout="C", - addrspace=AddressSpace.PRIVATE, - ) - - sig = ty_array(ty_shape, ty_dtype) - - def codegen( - context: DpexExpKernelTypingContext, - builder: IRBuilder, - sig: Signature, - args: list[llvmir.Value], - ): - shape = args[0] - ty_shape = sig.args[0] - ty_array = sig.return_type - - ary = make_spirv_generic_array_on_stack( - context, builder, ty_array, ty_shape, shape - ) - return ary._getvalue() # pylint: disable=protected-access - - return ( - sig, - codegen, - ) - - -@overload( - PrivateArray, - prefer_literal=True, - target=DPEX_KERNEL_EXP_TARGET_NAME, -) -def ol_private_array_ctor( - shape, - dtype, -): - """Overload of the constructor for the class +@type_callable(PrivateArray) +def type_interval(context): # pylint: disable=unused-argument + """Sets type of the constructor for the class class:`numba_dpex.kernel_api.PrivateArray`. Raises: @@ -81,11 +39,48 @@ def ol_private_array_ctor( type. """ - def ol_private_array_ctor_impl( - shape, - dtype, - ): - # pylint: disable=no-value-for-parameter - return _intrinsic_private_array_ctor(shape, dtype) + def typer(shape, dtype, fill_zeros=types.BooleanLiteral(False)): + require_literal(shape) + require_literal(fill_zeros) + + return USMNdArray( + dtype=_ty_parse_dtype(dtype), + ndim=_ty_parse_shape(shape), + layout="C", + addrspace=AddressSpace.PRIVATE, + ) + + return typer + + +@lower(PrivateArray, types.IntegerLiteral, types.Any, types.BooleanLiteral) +@lower(PrivateArray, types.Tuple, types.Any, types.BooleanLiteral) +@lower(PrivateArray, types.UniTuple, types.Any, types.BooleanLiteral) +@lower(PrivateArray, types.IntegerLiteral, types.Any) +@lower(PrivateArray, types.Tuple, types.Any) +@lower(PrivateArray, types.UniTuple, types.Any) +def dpex_private_array_lower( + context: DpexExpKernelTypingContext, + builder: IRBuilder, + sig: Signature, + args: list[llvmir.Value], +): + """Implements lower for the class:`numba_dpex.kernel_api.PrivateArray`""" + shape = args[0] + ty_shape = sig.args[0] + if len(sig.args) == 3: + fill_zeros = sig.args[-1].literal_value + else: + fill_zeros = False + ty_array = sig.return_type + + ary = make_spirv_generic_array_on_stack( + context, builder, ty_array, ty_shape, shape + ) + + if fill_zeros: + cgutils.memset( + builder, ary.data, builder.mul(ary.itemsize, ary.nitems), 0 + ) - return ol_private_array_ctor_impl + return ary._getvalue() # pylint: disable=protected-access diff --git a/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_registry.py b/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_registry.py new file mode 100644 index 0000000000..1fae06a258 --- /dev/null +++ b/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_registry.py @@ -0,0 +1,12 @@ +# SPDX-FileCopyrightText: 2024 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Implements the SPIR-V overloads for the kernel_api.PrivateArray class. +""" + +from numba.core.imputils import Registry + +registry = Registry() +lower = registry.lower diff --git a/numba_dpex/kernel_api/private_array.py b/numba_dpex/kernel_api/private_array.py index 7393cc71b8..95b9a7ae2a 100644 --- a/numba_dpex/kernel_api/private_array.py +++ b/numba_dpex/kernel_api/private_array.py @@ -7,7 +7,7 @@ kernel function. """ -from numpy import ndarray +import numpy as np class PrivateArray: @@ -16,10 +16,13 @@ class PrivateArray: inside kernel work item. """ - def __init__(self, shape, dtype) -> None: + def __init__(self, shape, dtype, fill_zeros=False) -> None: """Creates a new PrivateArray instance of the given shape and dtype.""" - self._data = ndarray(shape=shape, dtype=dtype) + if fill_zeros: + self._data = np.zeros(shape=shape, dtype=dtype) + else: + self._data = np.empty(shape=shape, dtype=dtype) def __getitem__(self, idx_obj): """Returns the value stored at the position represented by idx_obj in diff --git a/numba_dpex/kernel_api_impl/spirv/arrayobj.py b/numba_dpex/kernel_api_impl/spirv/arrayobj.py index e1e5742b28..325d0e4e18 100644 --- a/numba_dpex/kernel_api_impl/spirv/arrayobj.py +++ b/numba_dpex/kernel_api_impl/spirv/arrayobj.py @@ -41,7 +41,9 @@ def require_literal(literal_type: types.Type): for i, _ in enumerate(literal_type): if not isinstance(literal_type[i], types.Literal): - raise errors.TypingError("requires literal type") + raise errors.TypingError( + "requires each element of tuple literal type" + ) def make_spirv_array( # pylint: disable=too-many-arguments diff --git a/numba_dpex/kernel_api_impl/spirv/dispatcher.py b/numba_dpex/kernel_api_impl/spirv/dispatcher.py index 056c9ffc11..9aac39edb4 100644 --- a/numba_dpex/kernel_api_impl/spirv/dispatcher.py +++ b/numba_dpex/kernel_api_impl/spirv/dispatcher.py @@ -5,6 +5,7 @@ """Implements a new numba dispatcher class and a compiler class to compile and call numba_dpex.kernel decorated function. """ +import hashlib from collections import namedtuple from contextlib import ExitStack from typing import Tuple @@ -181,6 +182,9 @@ def _compile_to_spirv( # all linking libraries getting linked together and final optimization # including inlining of functions if an inlining level is specified. kernel_library.finalize() + + if config.DUMP_KERNEL_LLVM: + self._dump_kernel(kernel_fndesc, kernel_library) # Compiled the LLVM IR to SPIR-V kernel_spirv_module = spirv_generator.llvm_to_spirv( kernel_targetctx, @@ -268,20 +272,26 @@ def _compile_cached( kcres_attrs.append(kernel_device_ir_module) - if config.DUMP_KERNEL_LLVM: - with open( - cres.fndesc.llvm_func_name + ".ll", - "w", - encoding="UTF-8", - ) as fptr: - fptr.write(str(cres.library.final_module)) - except errors.TypingError as err: self._failed_cache[key] = err return False, err return True, _SPIRVKernelCompileResult(*kcres_attrs) + def _dump_kernel(self, fndesc, library): + """Dump kernel into file.""" + name = fndesc.llvm_func_name + if len(name) > 200: + sha256 = hashlib.sha256(name.encode("utf-8")).hexdigest() + name = name[:150] + "_" + sha256 + + with open( + name + ".ll", + "w", + encoding="UTF-8", + ) as fptr: + fptr.write(str(library.final_module)) + class SPIRVKernelDispatcher(Dispatcher): """Dispatcher class designed to compile kernel decorated functions. The diff --git a/numba_dpex/kernel_api_impl/spirv/spirv_generator.py b/numba_dpex/kernel_api_impl/spirv/spirv_generator.py index 1171faca4a..c731a9c75d 100644 --- a/numba_dpex/kernel_api_impl/spirv/spirv_generator.py +++ b/numba_dpex/kernel_api_impl/spirv/spirv_generator.py @@ -123,6 +123,7 @@ def finalize(self): llvm_spirv_args = [ "--spirv-ext=+SPV_EXT_shader_atomic_float_add", "--spirv-ext=+SPV_EXT_shader_atomic_float_min_max", + "--spirv-ext=+SPV_INTEL_arbitrary_precision_integers", ] for key in list(self.context.extra_compile_options.keys()): if key == LLVM_SPIRV_ARGS: diff --git a/numba_dpex/kernel_api_impl/spirv/target.py b/numba_dpex/kernel_api_impl/spirv/target.py index 4e51b9b8fd..4a1b4a42e2 100644 --- a/numba_dpex/kernel_api_impl/spirv/target.py +++ b/numba_dpex/kernel_api_impl/spirv/target.py @@ -383,12 +383,16 @@ def load_additional_registries(self): # pylint: disable=import-outside-toplevel from numba_dpex import printimpl from numba_dpex.dpnp_iface import dpnpimpl + from numba_dpex.experimental._kernel_dpcpp_spirv_overloads._registry import ( + registry as spirv_registry, + ) from numba_dpex.ocl import mathimpl, oclimpl self.insert_func_defn(oclimpl.registry.functions) self.insert_func_defn(mathimpl.registry.functions) self.insert_func_defn(dpnpimpl.registry.functions) self.install_registry(printimpl.registry) + self.install_registry(spirv_registry) # Replace dpnp math functions with their OpenCL versions. self.replace_dpnp_ufunc_with_ocl_intrinsics() diff --git a/numba_dpex/tests/experimental/test_private_array.py b/numba_dpex/tests/experimental/test_private_array.py index fcbf69b825..fa6af6f58b 100644 --- a/numba_dpex/tests/experimental/test_private_array.py +++ b/numba_dpex/tests/experimental/test_private_array.py @@ -23,6 +23,30 @@ def private_array_kernel(item: Item, a): a[i] += p[j] +def private_array_kernel_fill_true(item: Item, a): + i = item.get_linear_id() + p = PrivateArray(10, a.dtype, fill_zeros=True) + + for j in range(10): + p[j] = j * j + + a[i] = 0 + for j in range(10): + a[i] += p[j] + + +def private_array_kernel_fill_false(item: Item, a): + i = item.get_linear_id() + p = PrivateArray(10, a.dtype, fill_zeros=False) + + for j in range(10): + p[j] = j * j + + a[i] = 0 + for j in range(10): + a[i] += p[j] + + def private_2d_array_kernel(item: Item, a): i = item.get_linear_id() p = PrivateArray(shape=(5, 2), dtype=a.dtype) @@ -36,7 +60,13 @@ def private_2d_array_kernel(item: Item, a): @pytest.mark.parametrize( - "kernel", [private_array_kernel, private_2d_array_kernel] + "kernel", + [ + private_array_kernel, + private_array_kernel_fill_true, + private_array_kernel_fill_false, + private_2d_array_kernel, + ], ) @pytest.mark.parametrize( "call_kernel, decorator",