Skip to content

Commit

Permalink
Add dpex target intrinsic test
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Jan 5, 2024
1 parent 4242584 commit 011675c
Showing 1 changed file with 42 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
"""

import pytest
from numba import njit
from numba import njit, types
from numba.core import errors
from numba.extending import overload
from numba.extending import intrinsic, overload

from numba_dpex import dpjit
from numba_dpex.core.targets.dpjit_target import DPEX_TARGET_NAME
Expand All @@ -25,10 +25,38 @@ def ol_foo():
return lambda: 1


@intrinsic(target=DPEX_TARGET_NAME)
def intrinsic_foo(
ty_context,
):
"""A numba "intrinsic" function to inject dpctl.SyclEvent constructor code.
Args:
ty_context (numba.core.typing.context.Context): The typing context
for the codegen.
Returns:
tuple(numba.core.typing.templates.Signature, function): A tuple of
numba function signature type and a function object.
"""

sig = types.int32(types.void)

def codegen(context, builder, sig, args: list):
return context.get_constant(types.int32, 1)

return sig, codegen


def bar():
return foo()


def intrinsic_bar():
res = intrinsic_foo()
return res


def test_dpex_overload_from_njit():
bar_njit = njit(bar)

Expand All @@ -39,3 +67,15 @@ def test_dpex_overload_from_njit():
def test_dpex_overload_from_dpjit():
bar_dpjit = dpjit(bar)
bar_dpjit()


def test_dpex_intrinsic_from_njit():
bar_njit = njit(intrinsic_bar)

with pytest.raises(errors.TypingError):
bar_njit()


def test_dpex_intrinsic_from_dpjit():
bar_dpjit = dpjit(intrinsic_bar)
bar_dpjit()

0 comments on commit 011675c

Please sign in to comment.