Skip to content

Commit

Permalink
Merge pull request IntelPython#1269 from IntelPython/feature/set_over…
Browse files Browse the repository at this point in the history
…aloads_target_specific

Make overloads and intrinsic target specific
  • Loading branch information
Diptorup Deb authored Jan 5, 2024
2 parents 46cfa8c + 011675c commit 8c3cbed
Show file tree
Hide file tree
Showing 10 changed files with 154 additions and 34 deletions.
11 changes: 7 additions & 4 deletions numba_dpex/core/kernel_interface/indexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from numba.core.datamodel import default_manager
from numba.extending import intrinsic, overload

# can't import name because of the circular import
DPEX_TARGET_NAME = "dpex"


class Range(tuple):
"""A data structure to encapsulate a single kernel launch parameter.
Expand Down Expand Up @@ -231,7 +234,7 @@ def __eq__(self, other):
return False


@intrinsic
@intrinsic(target=DPEX_TARGET_NAME)
def _intrin_range_alloc(typingctx, ty_dim0, ty_dim1, ty_dim2, ty_range):
ty_retty = ty_range.instance_type
sig = ty_retty(
Expand Down Expand Up @@ -268,7 +271,7 @@ def codegen(context, builder, sig, args):
return sig, codegen


@intrinsic
@intrinsic(target=DPEX_TARGET_NAME)
def _intrin_ndrange_alloc(
typingctx, ty_global_range, ty_local_range, ty_ndrange
):
Expand Down Expand Up @@ -318,7 +321,7 @@ def codegen(context, builder, sig, args):
return sig, codegen


@overload(Range)
@overload(Range, target=DPEX_TARGET_NAME)
def _ol_range_init(dim0, dim1=None, dim2=None):
"""Numba overload of the Range constructor to make it usable inside an
njit and dpjit decorated function.
Expand Down Expand Up @@ -353,7 +356,7 @@ def impl(dim0, dim1=None, dim2=None):
return impl


@overload(NdRange)
@overload(NdRange, target=DPEX_TARGET_NAME)
def _ol_ndrange_init(global_range, local_range):
"""Numba overload of the NdRange constructor to make it usable inside an
njit and dpjit decorated function.
Expand Down
2 changes: 0 additions & 2 deletions numba_dpex/core/targets/dpjit_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

from functools import cached_property

from numba.core import utils
from numba.core.codegen import JITCPUCodegen
from numba.core.compiler_lock import global_compiler_lock
from numba.core.cpu import CPUContext
from numba.core.imputils import Registry
Expand Down
9 changes: 5 additions & 4 deletions numba_dpex/dpctl_iface/_intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@

import numba_dpex.dpctl_iface.libsyclinterface_bindings as sycl
from numba_dpex.core import types as dpex_types
from numba_dpex.core.targets.dpjit_target import DPEX_TARGET_NAME
from numba_dpex.dpctl_iface.wrappers import wrap_event_reference


@intrinsic
@intrinsic(target=DPEX_TARGET_NAME)
def sycl_event_create(
ty_context,
):
Expand All @@ -38,7 +39,7 @@ def codegen(context, builder: IRBuilder, sig, args: list):
return sig, codegen


@intrinsic
@intrinsic(target=DPEX_TARGET_NAME)
def sycl_event_wait(typingctx, ty_event: dpex_types.DpctlSyclEvent):
sig = types.void(dpex_types.DpctlSyclEvent())

Expand All @@ -55,15 +56,15 @@ def codegen(context, builder, signature, args):
return sig, codegen


@overload(dpctl.SyclEvent)
@overload(dpctl.SyclEvent, target=DPEX_TARGET_NAME)
def ol_dpctl_sycl_event_create():
"""Implementation of an overload to support dpctl.SyclEvent() inside
a dpjit function.
"""
return lambda: sycl_event_create()


@overload_method(dpex_types.DpctlSyclEvent, "wait")
@overload_method(dpex_types.DpctlSyclEvent, "wait", target=DPEX_TARGET_NAME)
def ol_dpctl_sycl_event_wait(
event,
):
Expand Down
25 changes: 14 additions & 11 deletions numba_dpex/dpnp_iface/_intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from numba_dpex.core.types.dpctl_types import DpctlSyclQueue
from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl

# can't import name because of the circular import
DPEX_TARGET_NAME = "dpex"

_QueueRefPayload = namedtuple(
"QueueRefPayload", ["queue_ref", "py_dpctl_sycl_queue_addr", "pyapi"]
)
Expand Down Expand Up @@ -305,7 +308,7 @@ def _empty_nd_impl(context, builder, arrtype, shapes, queue_ref):
return ary


@overload_classmethod(DpnpNdArray, "_usm_allocate")
@overload_classmethod(DpnpNdArray, "_usm_allocate", target=DPEX_TARGET_NAME)
def _ol_array_allocate(cls, allocsize, usm_type, queue):
"""Implements an allocator for dpnp.ndarrays."""

Expand All @@ -326,7 +329,7 @@ def _call_usm_allocator(arrtype, size, usm_type, queue):
numba_config.DISABLE_PERFORMANCE_WARNINGS = 1


@intrinsic
@intrinsic(target=DPEX_TARGET_NAME)
def intrin_usm_alloc(typingctx, allocsize, usm_type, queue):
"""Intrinsic to call into the allocator for Array"""

Expand Down Expand Up @@ -425,7 +428,7 @@ def fill_arrayobj(context, builder, ary, arrtype, queue_ref, fill_value):
return ary, arrtype


@intrinsic
@intrinsic(target=DPEX_TARGET_NAME)
def impl_dpnp_empty(
ty_context,
ty_shape,
Expand Down Expand Up @@ -495,7 +498,7 @@ def codegen(context, builder, sig, args):
return sig, codegen


@intrinsic
@intrinsic(target=DPEX_TARGET_NAME)
def impl_dpnp_zeros(
ty_context,
ty_shape,
Expand Down Expand Up @@ -572,7 +575,7 @@ def codegen(context, builder, sig, args):
return sig, codegen


@intrinsic
@intrinsic(target=DPEX_TARGET_NAME)
def impl_dpnp_ones(
ty_context,
ty_shape,
Expand Down Expand Up @@ -650,7 +653,7 @@ def codegen(context, builder, sig, args):
return sig, codegen


@intrinsic
@intrinsic(target=DPEX_TARGET_NAME)
def impl_dpnp_full(
ty_context,
ty_shape,
Expand Down Expand Up @@ -734,7 +737,7 @@ def codegen(context, builder, sig, args):
return signature, codegen


@intrinsic
@intrinsic(target=DPEX_TARGET_NAME)
def impl_dpnp_empty_like(
ty_context,
ty_x1,
Expand Down Expand Up @@ -813,7 +816,7 @@ def codegen(context, builder, sig, args):
return sig, codegen


@intrinsic
@intrinsic(target=DPEX_TARGET_NAME)
def impl_dpnp_zeros_like(
ty_context,
ty_x1,
Expand Down Expand Up @@ -901,7 +904,7 @@ def codegen(context, builder, sig, args):
return sig, codegen


@intrinsic
@intrinsic(target=DPEX_TARGET_NAME)
def impl_dpnp_ones_like(
ty_context,
ty_x1,
Expand Down Expand Up @@ -988,7 +991,7 @@ def codegen(context, builder, sig, args):
return sig, codegen


@intrinsic
@intrinsic(target=DPEX_TARGET_NAME)
def impl_dpnp_full_like(
ty_context,
ty_x1,
Expand Down Expand Up @@ -1079,7 +1082,7 @@ def codegen(context, builder, sig, args):
return signature, codegen


@intrinsic
@intrinsic(target=DPEX_TARGET_NAME)
def ol_dpnp_nd_array_sycl_queue(
ty_context,
ty_dpnp_nd_array: DpnpNdArray,
Expand Down
22 changes: 13 additions & 9 deletions numba_dpex/dpnp_iface/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
ol_dpnp_nd_array_sycl_queue,
)

# can't import name because of the circular import
DPEX_TARGET_NAME = "dpex"

# =========================================================================
# Helps to parse dpnp constructor arguments
# =========================================================================
Expand Down Expand Up @@ -164,7 +167,7 @@ def _parse_device_filter_string(device):
# =========================================================================


@overload(dpnp.empty, prefer_literal=True)
@overload(dpnp.empty, prefer_literal=True, target=DPEX_TARGET_NAME)
def ol_dpnp_empty(
shape,
dtype=None,
Expand Down Expand Up @@ -261,7 +264,7 @@ def impl(
raise errors.TypingError("Could not infer the rank of the ndarray.")


@overload(dpnp.zeros, prefer_literal=True)
@overload(dpnp.zeros, prefer_literal=True, target=DPEX_TARGET_NAME)
def ol_dpnp_zeros(
shape,
dtype=None,
Expand Down Expand Up @@ -355,7 +358,7 @@ def impl(
raise errors.TypingError("Could not infer the rank of the ndarray.")


@overload(dpnp.ones, prefer_literal=True)
@overload(dpnp.ones, prefer_literal=True, target=DPEX_TARGET_NAME)
def ol_dpnp_ones(
shape,
dtype=None,
Expand Down Expand Up @@ -449,7 +452,7 @@ def impl(
raise errors.TypingError("Could not infer the rank of the ndarray.")


@overload(dpnp.full, prefer_literal=True)
@overload(dpnp.full, prefer_literal=True, target=DPEX_TARGET_NAME)
def ol_dpnp_full(
shape,
fill_value,
Expand Down Expand Up @@ -558,7 +561,7 @@ def impl(
raise errors.TypingError("Could not infer the rank of the ndarray.")


@overload(dpnp.empty_like, prefer_literal=True)
@overload(dpnp.empty_like, prefer_literal=True, target=DPEX_TARGET_NAME)
def ol_dpnp_empty_like(
x1,
dtype=None,
Expand Down Expand Up @@ -683,7 +686,7 @@ def impl(
)


@overload(dpnp.zeros_like, prefer_literal=True)
@overload(dpnp.zeros_like, prefer_literal=True, target=DPEX_TARGET_NAME)
def ol_dpnp_zeros_like(
x1,
dtype=None,
Expand Down Expand Up @@ -807,7 +810,7 @@ def impl(
)


@overload(dpnp.ones_like, prefer_literal=True)
@overload(dpnp.ones_like, prefer_literal=True, target=DPEX_TARGET_NAME)
def ol_dpnp_ones_like(
x1,
dtype=None,
Expand Down Expand Up @@ -932,7 +935,7 @@ def impl(
)


@overload(dpnp.full_like, prefer_literal=True)
@overload(dpnp.full_like, prefer_literal=True, target=DPEX_TARGET_NAME)
def ol_dpnp_full_like(
x1,
fill_value,
Expand Down Expand Up @@ -1062,6 +1065,7 @@ def impl(
)


# TODO: target specific
@lower_builtin(operator.getitem, DpnpNdArray, types.Integer)
@lower_builtin(operator.getitem, DpnpNdArray, types.SliceType)
def getitem_arraynd_intp(context, builder, sig, args):
Expand All @@ -1088,7 +1092,7 @@ def getitem_arraynd_intp(context, builder, sig, args):
return ret


@overload_attribute(DpnpNdArray, "sycl_queue")
@overload_attribute(DpnpNdArray, "sycl_queue", target=DPEX_TARGET_NAME)
def dpnp_nd_array_sycl_queue(arr):
"""Returns :class:`dpctl.SyclQueue` object associated with USM data.
Expand Down
5 changes: 3 additions & 2 deletions numba_dpex/experimental/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from numba.extending import intrinsic

from numba_dpex import dpjit
from numba_dpex.core.targets.dpjit_target import DPEX_TARGET_NAME
from numba_dpex.core.targets.kernel_target import DpexKernelTargetContext
from numba_dpex.core.types import DpctlSyclEvent, NdRangeType, RangeType
from numba_dpex.core.utils import kernel_launcher as kl
Expand Down Expand Up @@ -49,7 +50,7 @@ def wrap_event_reference_tuple(ctx, builder, event1, event2):
return tup


@intrinsic(target="cpu")
@intrinsic(target=DPEX_TARGET_NAME)
def _submit_kernel_async(
typingctx,
ty_kernel_fn: Dispatcher,
Expand All @@ -68,7 +69,7 @@ def _submit_kernel_async(
)


@intrinsic(target="cpu")
@intrinsic(target=DPEX_TARGET_NAME)
def _submit_kernel_sync(
typingctx,
ty_kernel_fn: Dispatcher,
Expand Down
3 changes: 2 additions & 1 deletion numba_dpex/experimental/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@

from numba_dpex import dpjit
from numba_dpex.core.runtime.context import DpexRTContext
from numba_dpex.core.targets.dpjit_target import DPEX_TARGET_NAME


@intrinsic(target="cpu")
@intrinsic(target=DPEX_TARGET_NAME)
def _kernel_cache_size(
typingctx, # pylint: disable=W0613
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from numba.extending import intrinsic

from numba_dpex import dpjit
from numba_dpex.core.targets.dpjit_target import DPEX_TARGET_NAME


@intrinsic
@intrinsic(target=DPEX_TARGET_NAME)
def are_queues_equal(typingctx, ty_queue1, ty_queue2):
"""Calls dpctl's libsyclinterface's DPCTLQueue_AreEq to see if two
dpctl.SyclQueue objects point to the same sycl queue.
Expand Down
27 changes: 27 additions & 0 deletions numba_dpex/tests/dpjit_tests/dpnp/test_target_specific_overload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: 2024 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""
Tests if dpnp dpex specific overloads are not available at numba njit.
"""

import dpnp
import pytest
from numba import njit
from numba.core import errors

from numba_dpex import dpjit


@pytest.mark.parametrize("func", [dpnp.empty, dpnp.ones, dpnp.zeros])
def test_dpnp_dpex_target(func):
def dpnp_func():
func(10)

dpnp_func_njit = njit(dpnp_func)
dpnp_func_dpjit = dpjit(dpnp_func)

dpnp_func_dpjit()
with pytest.raises(errors.TypingError):
dpnp_func_njit()
Loading

0 comments on commit 8c3cbed

Please sign in to comment.