Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DOC: Update variables a and b to names consistent with comment documentation #60372

Closed
wants to merge 7 commits into from
76 changes: 41 additions & 35 deletions pandas/core/computation/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,23 +65,23 @@ def set_numexpr_threads(n=None) -> None:
ne.set_num_threads(n)


def _evaluate_standard(op, op_str, a, b):
def _evaluate_standard(op, op_str, left_op, right_op):
"""
Standard evaluation.
"""
if _TEST_MODE:
_store_test_result(False)
return op(a, b)
return op(left_op, right_op)


def _can_use_numexpr(op, op_str, a, b, dtype_check) -> bool:
def _can_use_numexpr(op, op_str, left_op, right_op, dtype_check) -> bool:
"""return a boolean if we WILL be using numexpr"""
if op_str is not None:
# required min elements (otherwise we are adding overhead)
if a.size > _MIN_ELEMENTS:
if left_op.size > _MIN_ELEMENTS:
# check for dtype compatibility
dtypes: set[str] = set()
for o in [a, b]:
for o in [left_op, right_op]:
# ndarray and Series Case
if hasattr(o, "dtype"):
dtypes |= {o.dtype.name}
Expand All @@ -93,43 +93,43 @@ def _can_use_numexpr(op, op_str, a, b, dtype_check) -> bool:
return False


def _evaluate_numexpr(op, op_str, a, b):
def _evaluate_numexpr(op, op_str, left_op, right_op):
result = None

if _can_use_numexpr(op, op_str, a, b, "evaluate"):
if _can_use_numexpr(op, op_str, left_op, right_op, "evaluate"):
is_reversed = op.__name__.strip("_").startswith("r")
if is_reversed:
# we were originally called by a reversed op method
a, b = b, a
left_op, right_op = right_op, left_op

a_value = a
b_value = b
left_value = left_op
right_value = right_op

try:
result = ne.evaluate(
f"a_value {op_str} b_value",
local_dict={"a_value": a_value, "b_value": b_value},
f"left_value {op_str} right_value",
local_dict={"left_value": left_value, "right_value": right_value},
casting="safe",
)
except TypeError:
# numexpr raises eg for array ** array with integers
# (https://github.com/pydata/numexpr/issues/379)
pass
except NotImplementedError:
if _bool_arith_fallback(op_str, a, b):
if _bool_arith_fallback(op_str, left_op, right_op):
pass
else:
raise

if is_reversed:
# reverse order to original for fallback
a, b = b, a
left_op, right_op = right_op, left_op

if _TEST_MODE:
_store_test_result(result is not None)

if result is None:
result = _evaluate_standard(op, op_str, a, b)
result = _evaluate_standard(op, op_str, left_op, right_op)

return result

Expand Down Expand Up @@ -170,24 +170,28 @@ def _evaluate_numexpr(op, op_str, a, b):
}


def _where_standard(cond, a, b):
def _where_standard(cond, left_op, right_op):
# Caller is responsible for extracting ndarray if necessary
return np.where(cond, a, b)
return np.where(cond, left_op, right_op)


def _where_numexpr(cond, a, b):
def _where_numexpr(cond, left_op, right_op):
# Caller is responsible for extracting ndarray if necessary
result = None

if _can_use_numexpr(None, "where", a, b, "where"):
if _can_use_numexpr(None, "where", left_op, right_op, "where"):
result = ne.evaluate(
"where(cond_value, a_value, b_value)",
local_dict={"cond_value": cond, "a_value": a, "b_value": b},
"where(cond_value, left_value, right_value)",
local_dict={
"cond_value": cond,
"left_value": left_op,
"right_value": right_op
},
casting="safe",
)

if result is None:
result = _where_standard(cond, a, b)
result = _where_standard(cond, left_op, right_op)

return result

Expand All @@ -206,13 +210,13 @@ def _has_bool_dtype(x):
_BOOL_OP_UNSUPPORTED = {"+": "|", "*": "&", "-": "^"}


def _bool_arith_fallback(op_str, a, b) -> bool:
def _bool_arith_fallback(op_str, left_op, right_op) -> bool:
"""
Check if we should fallback to the python `_evaluate_standard` in case
of an unsupported operation by numexpr, which is the case for some
boolean ops.
"""
if _has_bool_dtype(a) and _has_bool_dtype(b):
if _has_bool_dtype(left_op) and _has_bool_dtype(right_op):
if op_str in _BOOL_OP_UNSUPPORTED:
warnings.warn(
f"evaluating in Python space because the {op_str!r} "
Expand All @@ -224,40 +228,42 @@ def _bool_arith_fallback(op_str, a, b) -> bool:
return False


def evaluate(op, a, b, use_numexpr: bool = True):
def evaluate(op, left_op, right_op, use_numexpr: bool = True):
"""
Evaluate and return the expression of the op on a and b.
Evaluate and return the expression of the op on left_op and right_op.

Parameters
----------
op : the actual operator
a : left operand
b : right operand
left_op : left operand
right_op : right operand
use_numexpr : bool, default True
Whether to try to use numexpr.
"""
op_str = _op_str_mapping[op]
if op_str is not None:
if use_numexpr:
# error: "None" not callable
return _evaluate(op, op_str, a, b) # type: ignore[misc]
return _evaluate_standard(op, op_str, a, b)
return _evaluate(op, op_str, left_op, right_op) # type: ignore[misc]
return _evaluate_standard(op, op_str, left_op, right_op)


def where(cond, a, b, use_numexpr: bool = True):
def where(cond, left_op, right_op, use_numexpr: bool = True):
"""
Evaluate the where condition cond on a and b.
Evaluate the where condition cond on left_op and right_op.

Parameters
----------
cond : np.ndarray[bool]
a : return if cond is True
b : return if cond is False
left_op : return if cond is True
right_op : return if cond is False
use_numexpr : bool, default True
Whether to try to use numexpr.
"""
assert _where is not None
return _where(cond, a, b) if use_numexpr else _where_standard(cond, a, b)
return (_where(cond, left_op, right_op)
if use_numexpr
else _where_standard(cond, left_op, right_op))


def set_test_mode(v: bool = True) -> None:
Expand Down
Loading