diff --git a/test/torchaudio_unittest/prototype/functional/dsp_utils.py b/test/torchaudio_unittest/prototype/functional/dsp_utils.py index fb0300a9d6..44c0cac3c3 100644 --- a/test/torchaudio_unittest/prototype/functional/dsp_utils.py +++ b/test/torchaudio_unittest/prototype/functional/dsp_utils.py @@ -1,4 +1,5 @@ import numpy as np +import numpy.typing as npt def oscillator_bank( @@ -43,8 +44,8 @@ def freq_ir(magnitudes): def exp_sigmoid( - input: np.ndarray, exponent: float = 10.0, max_value: float = 2.0, threshold: float = 1e-7 -) -> np.ndarray: + input: npt.NDArray, exponent: float = 10.0, max_value: float = 2.0, threshold: float = 1e-7 +) -> npt.NDArray: """Exponential Sigmoid pointwise nonlinearity (Numpy version). Implements the equation: ``max_value`` * sigmoid(``input``) ** (log(``exponent``)) + ``threshold``