Skip to content

Commit

Permalink
Merge pull request IntelPython#1434 from IntelPython/fix/dpnp_nin_issue
Browse files Browse the repository at this point in the history
Fix for dpex failure caused by addition of nin, nout and types
  • Loading branch information
Diptorup Deb authored Apr 17, 2024
2 parents 1336f76 + c793313 commit 4708ac7
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 16 deletions.
35 changes: 28 additions & 7 deletions numba_dpex/core/typing/dpnpdecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#
# SPDX-License-Identifier: Apache-2.0

import logging

import dpnp
import numpy as np
from numba.core import types
Expand Down Expand Up @@ -34,14 +36,33 @@ class DpnpRulesArrayOperator(NumpyRulesArrayOperator):
@property
def ufunc(self):
try:
op = getattr(dpnp, self._op_map[self.key])
dpnpop = getattr(dpnp, self._op_map[self.key])
npop = getattr(np, self._op_map[self.key])
op.nin = npop.nin
op.nout = npop.nout
op.nargs = npop.nargs
op.types = npop.types
op.is_dpnp_ufunc = True
return op
if not hasattr(dpnpop, "nin"):
dpnpop.nin = npop.nin
if not hasattr(dpnpop, "nout"):
dpnpop.nout = npop.nout
if not hasattr(dpnpop, "nargs"):
dpnpop.nargs = dpnpop.nin + dpnpop.nout

# Check if the dpnp operation has a `types` attribute and if an
# AttributeError gets raised then "monkey patch" the attribute from
# numpy. If the attribute lookup raised a ValueError, it indicates
# that dpnp could not be resolve the supported types for the
# operation. Dpnp will fail to resolve the `types` if no SYCL
# devices are available on the system. For such a scenario, we print
# a user-level warning.
try:
dpnpop.types
except ValueError:
logging.exception(
f"The types attribute for the {dpnpop} fuction could not "
"be determined."
)
except AttributeError:
dpnpop.types = npop.types
dpnpop.is_dpnp_ufunc = True
return dpnpop
except:
pass

Expand Down
50 changes: 41 additions & 9 deletions numba_dpex/dpnp_iface/dpnp_ufunc_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import copy
import logging

import dpnp
import numpy as np
Expand Down Expand Up @@ -56,6 +57,7 @@ def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db):
# variable is passed by value
from numba.np.ufunc_db import _ufunc_db

failed_dpnpop_types_lst = []
for ufuncop in dpnpdecl.supported_ufuncs:
if ufuncop == "erf":
op = getattr(dpnp, "erf")
Expand All @@ -72,20 +74,50 @@ def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db):
"d->d": mathimpl.lower_ocl_impl[("erf", (_unary_d_d))],
}
else:
op = getattr(dpnp, ufuncop)
dpnpop = getattr(dpnp, ufuncop)
npop = getattr(np, ufuncop)
op.nin = npop.nin
op.nout = npop.nout
op.nargs = npop.nargs
op.types = npop.types
op.is_dpnp_ufunc = True
if not hasattr(dpnpop, "nin"):
dpnpop.nin = npop.nin
if not hasattr(dpnpop, "nout"):
dpnpop.nout = npop.nout
if not hasattr(dpnpop, "nargs"):
dpnpop.nargs = dpnpop.nin + dpnpop.nout

# Check if the dpnp operation has a `types` attribute and if an
# AttributeError gets raised then "monkey patch" the attribute from
# numpy. If the attribute lookup raised a ValueError, it indicates
# that dpnp could not be resolve the supported types for the
# operation. Dpnp will fail to resolve the `types` if no SYCL
# devices are available on the system. For such a scenario, we log
# dpnp operations for which the ValueError happened and print them
# as a user-level warning. It is done this way so that the failure
# to load the dpnpdecl registry due to the ValueError does not
# impede a user from importing numba-dpex.
try:
dpnpop.types
except ValueError:
failed_dpnpop_types_lst.append(ufuncop)
except AttributeError:
dpnpop.types = npop.types

dpnpop.is_dpnp_ufunc = True
cp = copy.copy(_ufunc_db[npop])
ufunc_db.update({op: cp})
for key in list(ufunc_db[op].keys()):
ufunc_db.update({dpnpop: cp})
for key in list(ufunc_db[dpnpop].keys()):
if (
"FF->" in key
or "DD->" in key
or "F->" in key
or "D->" in key
):
ufunc_db[op].pop(key)
ufunc_db[dpnpop].pop(key)

if failed_dpnpop_types_lst:
try:
getattr(dpnp, failed_dpnpop_types_lst[0]).types
except ValueError:
ops = " ".join(failed_dpnpop_types_lst)
logging.exception(
"The types attribute for the following dpnp ops could not be "
f"determined: {ops}"
)

0 comments on commit 4708ac7

Please sign in to comment.