From 04107813b8e92c737e06d60046a13f76e22ef6f4 Mon Sep 17 00:00:00 2001 From: Kai Striega Date: Sat, 23 Mar 2024 13:23:17 +1100 Subject: [PATCH] ENH: nper: Raise an error if arrays are of invalid shape --- numpy_financial/_financial.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/numpy_financial/_financial.py b/numpy_financial/_financial.py index cd6bf4e..08e2ad6 100644 --- a/numpy_financial/_financial.py +++ b/numpy_financial/_financial.py @@ -295,8 +295,6 @@ def _nper_native(rates, pmts, pvs, fvs, whens, out): ) - - def nper(rate, pmt, pv, fv=0, when='end'): """Compute the number of periodic payments. @@ -366,7 +364,27 @@ def nper(rate, pmt, pv, fv=0, when='end'): fv_inner = np.atleast_1d(fv) when_inner = np.atleast_1d(when) - # TODO: Validate ``*_inner`` array shapes + # TODO: I don't like repeating myself this often, refactor into a function + # that checks all of the arrays at once. + if rate_inner.ndim != 1: + msg = "invalid shape for rates. Rate must be either a scalar or 1d array" + raise ValueError(msg) + + if pmt_inner.ndim != 1: + msg = "invalid shape for pmt. Payments must be either a scalar or 1d array" + raise ValueError(msg) + + if pv_inner.ndim != 1: + msg = "invalid shape for pv. Present value must be either a scalar or 1d array" + raise ValueError(msg) + + if fv_inner.ndim != 1: + msg = "invalid shape for fv. Future value must be either a scalar or 1d array" + raise ValueError(msg) + + if when_inner.ndim != 1: + msg = "invalid shape for when. When must be either a scalar or 1d array" + raise ValueError(msg) out_shape = _get_output_array_shape( rate_inner, pmt_inner, pv_inner, fv_inner, when_inner