-
-
Notifications
You must be signed in to change notification settings - Fork 82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: Proof of Concept: Move NPV to Numba #91
Conversation
This seems to be more straight forward to maintain.
@nb.jit(forceobj=True) # Need ``forceobj`` to support decimal.Decimal | ||
def _npv_decimal(rates, cashflows, result): | ||
r"""Version of the ``npv`` function supporting ``decimal.Decimal`` types | ||
|
||
Warnings | ||
-------- | ||
For internal use only, note that this function performs no error checking. | ||
""" | ||
for i in range(rates.shape[0]): | ||
for j in range(cashflows.shape[0]): | ||
acc = Decimal("0.0") | ||
for t in range(cashflows.shape[1]): | ||
acc += cashflows[j, t] / ((Decimal("1.0") + rates[i]) ** t) | ||
result[i, j] = acc |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is essentially a copy of _npv
with the constants replaced with decimal.Decimal
s. I've tried doing something along the lines of dtype = Decimal if arr.dtype = np.dtype("O") else arr.dtype
but this wouldn't compile in nopython
mode. I'm trying to think of a neater way to implement this but haven't been able to think of one yet
r"""Native version of the ``npv`` function. | ||
|
||
Warnings | ||
-------- | ||
For internal use only, note that this function performs no error checking. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Documentation still needs to be improved; this is the workhorse powering the npv
function.
@@ -825,6 +827,40 @@ def irr(values, *, guess=None, tol=1e-12, maxiter=100, raise_exceptions=False): | |||
return np.nan | |||
|
|||
|
|||
@nb.njit(parallel=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure if it's worth parallelising this function.
|
||
r = np.atleast_1d(rate) | ||
v = np.atleast_2d(values) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO: Add checking of array shapes if they are compatible.
actual_npvs = npf.npv(0.05, cashflows) | ||
assert_allclose(actual_npvs, expected_npvs) | ||
|
||
@pytest.mark.parametrize("dtype", [Decimal, float]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO: check with other dtypes e.g. float32
for i, r in enumerate(rates): | ||
for j, cf in enumerate(cashflows): | ||
expected[i, j] = npf.npv(r, cf).item() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here we're using a for loop to calculate the equivalent to the "broadcasting" operation. This is slow, however I am more confident that it is correct.
Closing as Numba hasn't released Python 3.12 support yet. I'll also be exploring the Cython approach as an alternative approach. |
This PR moves the Net Present Value function to Numba
It is still a WIP, I'm not sure if the design decisions I've made are correct. Feedback is welcome.
I've tried using
numba.guvectorize
to make gufuncs out of the NumPy-Financial functions. This didn't work well (because we're not using "proper" gufuncs. Instead I've written them as for loops that are spead up with numba'sjit
/njit
functions.This leads to some duplication as we need both a
jit
andnjit
decorator to supportdecimal.Decimal
(which is a PyObject).Note that Numba does not support Python3.12 yet, this is expected to be out soon (in the 0.59 release) see the tracking issue for more details