diff --git a/numba_dpex/core/typing/dpnpdecl.py b/numba_dpex/core/typing/dpnpdecl.py index 6856f8ea1f..5cf7fd78d9 100644 --- a/numba_dpex/core/typing/dpnpdecl.py +++ b/numba_dpex/core/typing/dpnpdecl.py @@ -2,6 +2,8 @@ # # SPDX-License-Identifier: Apache-2.0 +import logging + import dpnp import numpy as np from numba.core import types @@ -34,14 +36,33 @@ class DpnpRulesArrayOperator(NumpyRulesArrayOperator): @property def ufunc(self): try: - op = getattr(dpnp, self._op_map[self.key]) + dpnpop = getattr(dpnp, self._op_map[self.key]) npop = getattr(np, self._op_map[self.key]) - op.nin = npop.nin - op.nout = npop.nout - op.nargs = npop.nargs - op.types = npop.types - op.is_dpnp_ufunc = True - return op + if not hasattr(dpnpop, "nin"): + dpnpop.nin = npop.nin + if not hasattr(dpnpop, "nout"): + dpnpop.nout = npop.nout + if not hasattr(dpnpop, "nargs"): + dpnpop.nargs = dpnpop.nin + dpnpop.nout + + # Check if the dpnp operation has a `types` attribute and if an + # AttributeError gets raised then "monkey patch" the attribute from + # numpy. If the attribute lookup raised a ValueError, it indicates + # that dpnp could not be resolve the supported types for the + # operation. Dpnp will fail to resolve the `types` if no SYCL + # devices are available on the system. For such a scenario, we print + # a user-level warning. + try: + dpnpop.types + except ValueError: + logging.exception( + f"The types attribute for the {dpnpop} fuction could not " + "be determined." + ) + except AttributeError: + dpnpop.types = npop.types + dpnpop.is_dpnp_ufunc = True + return dpnpop except: pass diff --git a/numba_dpex/dpnp_iface/dpnp_ufunc_db.py b/numba_dpex/dpnp_iface/dpnp_ufunc_db.py index def71adb17..3cd1945595 100644 --- a/numba_dpex/dpnp_iface/dpnp_ufunc_db.py +++ b/numba_dpex/dpnp_iface/dpnp_ufunc_db.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import copy +import logging import dpnp import numpy as np @@ -56,6 +57,7 @@ def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db): # variable is passed by value from numba.np.ufunc_db import _ufunc_db + failed_dpnpop_types_lst = [] for ufuncop in dpnpdecl.supported_ufuncs: if ufuncop == "erf": op = getattr(dpnp, "erf") @@ -72,20 +74,50 @@ def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db): "d->d": mathimpl.lower_ocl_impl[("erf", (_unary_d_d))], } else: - op = getattr(dpnp, ufuncop) + dpnpop = getattr(dpnp, ufuncop) npop = getattr(np, ufuncop) - op.nin = npop.nin - op.nout = npop.nout - op.nargs = npop.nargs - op.types = npop.types - op.is_dpnp_ufunc = True + if not hasattr(dpnpop, "nin"): + dpnpop.nin = npop.nin + if not hasattr(dpnpop, "nout"): + dpnpop.nout = npop.nout + if not hasattr(dpnpop, "nargs"): + dpnpop.nargs = dpnpop.nin + dpnpop.nout + + # Check if the dpnp operation has a `types` attribute and if an + # AttributeError gets raised then "monkey patch" the attribute from + # numpy. If the attribute lookup raised a ValueError, it indicates + # that dpnp could not be resolve the supported types for the + # operation. Dpnp will fail to resolve the `types` if no SYCL + # devices are available on the system. For such a scenario, we log + # dpnp operations for which the ValueError happened and print them + # as a user-level warning. It is done this way so that the failure + # to load the dpnpdecl registry due to the ValueError does not + # impede a user from importing numba-dpex. + try: + dpnpop.types + except ValueError: + failed_dpnpop_types_lst.append(ufuncop) + except AttributeError: + dpnpop.types = npop.types + + dpnpop.is_dpnp_ufunc = True cp = copy.copy(_ufunc_db[npop]) - ufunc_db.update({op: cp}) - for key in list(ufunc_db[op].keys()): + ufunc_db.update({dpnpop: cp}) + for key in list(ufunc_db[dpnpop].keys()): if ( "FF->" in key or "DD->" in key or "F->" in key or "D->" in key ): - ufunc_db[op].pop(key) + ufunc_db[dpnpop].pop(key) + + if failed_dpnpop_types_lst: + try: + getattr(dpnp, failed_dpnpop_types_lst[0]).types + except ValueError: + ops = " ".join(failed_dpnpop_types_lst) + logging.exception( + "The types attribute for the following dpnp ops could not be " + f"determined: {ops}" + )