Skip to content

Commit

Permalink
ENH: nper: Perform broadcast rework
Browse files Browse the repository at this point in the history
This commit rewrites the ``nper`` function to mimic broadcasting.
As this requires writing manual for loops numba is used to speed
up the calculations. Further a fuzz test is added.
  • Loading branch information
Kai-Striega committed Mar 23, 2024
1 parent 6b8b1dc commit 0b83101
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 22 deletions.
85 changes: 63 additions & 22 deletions numpy_financial/_financial.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,42 @@ def pmt(rate, nper, pv, fv=0, when='end'):
return -(fv + pv * temp) / fact


@nb.njit
def _nper_inner_loop(rate, pmt, pv, fv, when):
if rate == 0.0:
if pmt == 0.0:
# If no repayments are made the payments will go on forever
return np.inf
else:
return -(fv + pv) / pmt
else:
# We know that rate != 0.0, so we are sure this won't cause a ZeroDivisionError
z = pmt * (1.0 + rate * when) / rate
try:
numer = np.log((-fv + z) / (pv + z))
denom = np.log(1.0 + rate)
return numer / denom
except Exception: # As of March 24, numba only supports generic exceptions
# TODO: There are several ``ZeroDivisionError``s here.
# We need to figure out exactly what's causing these
# and return financially sensible values.
return np.nan


@nb.njit
def _nper_native(rates, pmts, pvs, fvs, whens, out):
for rate in range(rates.shape[0]):
for pmt in range(pmts.shape[0]):
for pv in range(pvs.shape[0]):
for fv in range(fvs.shape[0]):
for when in range(whens.shape[0]):
out[rate, pmt, pv, fv, when] = _nper_inner_loop(
rates[rate], pmts[pmt], pvs[pv], fvs[fv], whens[when]
)




def nper(rate, pmt, pv, fv=0, when='end'):
"""Compute the number of periodic payments.
Expand Down Expand Up @@ -297,43 +333,48 @@ def nper(rate, pmt, pv, fv=0, when='end'):
If you only had $150/month to pay towards the loan, how long would it take
to pay-off a loan of $8,000 at 7% annual interest?
>>> print(np.round(npf.nper(0.07/12, -150, 8000), 5))
>>> round(npf.nper(0.07/12, -150, 8000), 5)
64.07335
So, over 64 months would be required to pay off the loan.
The same analysis could be done with several different interest rates
and/or payments and/or total amounts to produce an entire table.
>>> npf.nper(*(np.ogrid[0.07/12: 0.08/12: 0.01/12,
... -150 : -99 : 50 ,
... 8000 : 9001 : 1000]))
array([[[ 64.07334877, 74.06368256],
[108.07548412, 127.99022654]],
>>> rates = [0.05, 0.06, 0.07]
>>> payments = [100, 200, 300]
>>> amounts = [7_000, 8_000, 9_000]
>>> npf.nper(rates, payments, amounts).round(3)
array([[[-30.827, -32.987, -34.94 ],
[-20.734, -22.517, -24.158],
[-15.847, -17.366, -18.78 ]],
<BLANKLINE>
[[-28.294, -30.168, -31.857],
[-19.417, -21.002, -22.453],
[-15.025, -16.398, -17.67 ]],
<BLANKLINE>
[[ 66.12443902, 76.87897353],
[114.70165583, 137.90124779]]])
[[-26.234, -27.891, -29.381],
[-18.303, -19.731, -21.034],
[-14.311, -15.566, -16.722]]])
"""
when = _convert_when(when)
rate, pmt, pv, fv, when = np.broadcast_arrays(rate, pmt, pv, fv, when)
nper_array = np.empty_like(rate, dtype=np.float64)

zero = rate == 0
nonzero = ~zero
rate_inner = np.atleast_1d(rate)
pmt_inner = np.atleast_1d(pmt)
pv_inner = np.atleast_1d(pv)
fv_inner = np.atleast_1d(fv)
when_inner = np.atleast_1d(when)

with np.errstate(divide='ignore'):
# Infinite numbers of payments are okay, so ignore the
# potential divide by zero.
nper_array[zero] = -(fv[zero] + pv[zero]) / pmt[zero]
# TODO: Validate ``*_inner`` array shapes

nonzero_rate = rate[nonzero]
z = pmt[nonzero] * (1 + nonzero_rate * when[nonzero]) / nonzero_rate
nper_array[nonzero] = (
np.log((-fv[nonzero] + z) / (pv[nonzero] + z))
/ np.log(1 + nonzero_rate)
out_shape = _get_output_array_shape(
rate_inner, pmt_inner, pv_inner, fv_inner, when_inner
)
out = np.empty(out_shape)
_nper_native(rate_inner, pmt_inner, pv_inner, fv_inner, when_inner, out)

return nper_array
return _ufunc_like(out)


def _value_like(arr, value):
Expand Down
15 changes: 15 additions & 0 deletions tests/test_financial.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def uint_dtype():
shape=npst.array_shapes(min_dims=0, max_dims=1, min_side=0, max_side=5),
)

when_strategy = st.sampled_from(
['end', 'begin', 'e', 'b', 0, 1, 'beginning', 'start', 'finish']
)


def assert_decimal_close(actual, expected, tol=Decimal("1e-7")):
# Check if both actual and expected are iterable (like arrays)
Expand Down Expand Up @@ -426,6 +430,17 @@ def test_broadcast(self):
npf.nper(0.075, -2000, 0, 100000.0, [0, 1]), [21.5449442, 20.76156441], 4
)

@given(
rates=short_scalar_array,
payments=short_scalar_array,
present_values=short_scalar_array,
future_values=short_scalar_array,
whens=when_strategy,
)
@settings(deadline=None) # ignore jit compilation of a function
def test_fuzz(self, rates, payments, present_values, future_values, whens):
npf.nper(rates, payments, present_values, future_values, whens)


class TestPpmt:
def test_float(self):
Expand Down

0 comments on commit 0b83101

Please sign in to comment.