Skip to content

Commit

Permalink
Add dpex target tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Jan 5, 2024
1 parent a186cf0 commit 4242584
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
27 changes: 27 additions & 0 deletions numba_dpex/tests/dpjit_tests/dpnp/test_target_specific_overload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: 2024 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""
Tests if dpnp dpex specific overloads are not available at numba njit.
"""

import dpnp
import pytest
from numba import njit
from numba.core import errors

from numba_dpex import dpjit


@pytest.mark.parametrize("func", [dpnp.empty, dpnp.ones, dpnp.zeros])
def test_dpnp_dpex_target(func):
def dpnp_func():
func(10)

dpnp_func_njit = njit(dpnp_func)
dpnp_func_dpjit = dpjit(dpnp_func)

dpnp_func_dpjit()
with pytest.raises(errors.TypingError):
dpnp_func_njit()
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# SPDX-FileCopyrightText: 2024 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""
Tests if dpex target overloads are not available at numba.njit and only
available at numba_dpex.dpjit.
"""

import pytest
from numba import njit
from numba.core import errors
from numba.extending import overload

from numba_dpex import dpjit
from numba_dpex.core.targets.dpjit_target import DPEX_TARGET_NAME


def foo():
return 1


@overload(foo, target=DPEX_TARGET_NAME)
def ol_foo():
return lambda: 1


def bar():
return foo()


def test_dpex_overload_from_njit():
bar_njit = njit(bar)

with pytest.raises(errors.TypingError):
bar_njit()


def test_dpex_overload_from_dpjit():
bar_dpjit = dpjit(bar)
bar_dpjit()

0 comments on commit 4242584

Please sign in to comment.