Skip to content

Commit

Permalink
mypy fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeitsperre committed Apr 22, 2024
1 parent 689d3ae commit 5d0dac6
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 33 deletions.
47 changes: 25 additions & 22 deletions xclim/indices/fire/_cffwis.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@
"overwintering_drought_code",
]

default_params = dict(
temp_start_thresh=(12, "degC"),
temp_end_thresh=(5, "degC"),
default_params: dict[str, int | float | tuple[float, str]] = dict(
temp_start_thresh=(12.0, "degC"),
temp_end_thresh=(5.0, "degC"),
snow_thresh=(0.01, "m"),
temp_condition_days=3,
snow_condition_days=3,
Expand Down Expand Up @@ -388,9 +388,13 @@ def _duff_moisture_code(


@vectorize(nopython=True)
def _drought_code(
t: np.ndarray, p: np.ndarray, mth: np.ndarray, lat: float, dc0: float
) -> np.ndarray: # pragma: no cover
def _drought_code( # pragma: no cover
t: np.ndarray,
p: np.ndarray,
mth: np.ndarray,
lat: float,
dc0: float,
) -> np.ndarray:
"""Compute the drought code over one time step.
Parameters
Expand All @@ -411,10 +415,10 @@ def _drought_code(
array_like
Drought code at the current timestep
"""
fl = _day_length_factor(lat, mth)
fl = _day_length_factor(lat, mth) # type: ignore

if t < -2.8:
t = -2.8
t = -2.8 # type: ignore
pe = (0.36 * (t + 2.8) + fl) / 2 # *Eq.22*#
pe = max(pe, 0.0)

Expand All @@ -431,7 +435,7 @@ def _drought_code(
dc = pe
else: # f p <= 2.8:
dc = dc0 + pe
return dc
return dc # type: ignore


def initial_spread_index(ws: np.ndarray, ffmc: np.ndarray) -> np.ndarray:
Expand All @@ -451,7 +455,7 @@ def initial_spread_index(ws: np.ndarray, ffmc: np.ndarray) -> np.ndarray:
"""
mo = 147.2 * (101.0 - ffmc) / (59.5 + ffmc) # *Eq.1*#
ff = 19.1152 * np.exp(mo * -0.1386) * (1.0 + (mo**5.31) / 49300000.0) # *Eq.25*#
isi = ff * np.exp(0.05039 * ws) # *Eq.26*#
isi: np.ndarray = ff * np.exp(0.05039 * ws) # *Eq.26*#
return isi


Expand Down Expand Up @@ -503,7 +507,7 @@ def fire_weather_index(isi, bui):
return fwi


def daily_severity_rating(fwi: np.ndarray) -> np.ndarry:
def daily_severity_rating(fwi: np.ndarray) -> np.ndarray:
"""Daily severity rating.
Parameters
Expand Down Expand Up @@ -548,6 +552,7 @@ def _overwintering_drought_code(DCf, wpr, a, b, minDC): # pragma: no cover
# SECTION 2 : Iterators


# FIXME: default_params should be supplied within the logic of the function.
def _fire_season(
tas: np.ndarray,
snd: np.ndarray | None = None,
Expand Down Expand Up @@ -1056,15 +1061,16 @@ def fire_weather_ufunc( # noqa: C901
)
# Arg order : tas, pr, hurs, sfcWind, snd, mth, lat, season_mask, dc0, dmc0, ffmc0, winter_pr
# 0 1 2 3 4 5 6 7 8 9 10 11
args = [None] * 12
input_core_dims = [[]] * 12
args: list[xr.DataArray | None] = [None] * 12
input_core_dims: list[list[str | None]] = [[]] * 12

# Verification of all arguments
for i, (arg, name, usedby, has_time_dim) in enumerate(needed_args):
if any([ind in indexes + [season_method] for ind in usedby]):
if arg is None:
raise TypeError(
f"Missing input argument {name} for index combination {indexes} with fire season method '{season_method}'"
f"Missing input argument {name} for index combination {indexes} "
f"with fire season method '{season_method}'."
)
args[i] = arg
input_core_dims[i] = ["time"] if has_time_dim else []
Expand All @@ -1078,17 +1084,14 @@ def fire_weather_ufunc( # noqa: C901
raise ValueError("'dry_start' must be one of None, 'CFS' or 'GFWED'.")

# Always pass the previous codes.
if dc0 is None:
dc0 = xr.full_like(tas.isel(time=0), np.nan)
if dmc0 is None:
dmc0 = xr.full_like(tas.isel(time=0), np.nan)
if ffmc0 is None:
ffmc0 = xr.full_like(tas.isel(time=0), np.nan)
args[8:11] = [dc0, dmc0, ffmc0]
_dc0 = xr.full_like(tas.isel(time=0), np.nan) if dc0 is None else dc0
_dmc0 = xr.full_like(tas.isel(time=0), np.nan) if dmc0 is None else dmc0
_ffmc0 = xr.full_like(tas.isel(time=0), np.nan) if ffmc0 is None else ffmc0
args[8:11] = [_dc0, _dmc0, _ffmc0]

# Output config from the current indexes list
outputs = indexes
output_dtypes = [tas.dtype] * len(indexes)
output_dtypes: list[np.dtype] = [tas.dtype] * len(indexes)
output_core_dims = len(indexes) * [("time",)]

if season_mask is not None:
Expand Down
18 changes: 9 additions & 9 deletions xclim/indices/fire/_ffdi.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
nopython=True,
cache=True,
)
def _keetch_byram_drought_index(p, t, pa, kbdi0, kbdi: float): # pragma: no cover
def _keetch_byram_drought_index(p, t, pa, kbdi0, kbdi): # pragma: no cover
"""Compute the Keetch-Byram drought (KBDI) index.
Parameters
Expand Down Expand Up @@ -239,7 +239,7 @@ def keetch_byram_drought_index(
:cite:cts:`ffdi-keetch_1968,ffdi-finkele_2006,ffdi-holgate_2017,ffdi-dolling_2005`
"""

def _keetch_byram_drought_index_pass(pr, tasmax, pr_annual, kbdi0):
def _keetch_byram_drought_index_pass(_pr, _tasmax, _pr_annual, _kbdi0):
"""Pass inputs on to guvectorized function `_keetch_byram_drought_index`.
This function is actually only required as `xr.apply_ufunc` will not receive
Expand All @@ -249,7 +249,7 @@ def _keetch_byram_drought_index_pass(pr, tasmax, pr_annual, kbdi0):
--------
DO NOT CALL DIRECTLY, use `keetch_byram_drought_index` instead.
"""
return _keetch_byram_drought_index(pr, tasmax, pr_annual, kbdi0)
return _keetch_byram_drought_index(_pr, _tasmax, _pr_annual, _kbdi0)

pr = convert_units_to(pr, "mm/day", context="hydro")
tasmax = convert_units_to(tasmax, "C")
Expand All @@ -259,7 +259,7 @@ def _keetch_byram_drought_index_pass(pr, tasmax, pr_annual, kbdi0):
else:
kbdi0 = xr.full_like(pr.isel(time=0), 0)

kbdi = xr.apply_ufunc(
kbdi: xr.DataArray = xr.apply_ufunc(
_keetch_byram_drought_index_pass,
pr,
tasmax,
Expand All @@ -270,7 +270,7 @@ def _keetch_byram_drought_index_pass(pr, tasmax, pr_annual, kbdi0):
dask="parallelized",
output_dtypes=[pr.dtype],
)
kbdi.attrs["units"] = "mm/day"
kbdi = kbdi.assign_attrs(units="mm/day")
return kbdi


Expand Down Expand Up @@ -317,7 +317,7 @@ def griffiths_drought_factor(
:cite:cts:`ffdi-griffiths_1999,ffdi-finkele_2006,ffdi-holgate_2017`
"""

def _griffiths_drought_factor_pass(pr, smd, lim):
def _griffiths_drought_factor_pass(_pr, _smd, _lim):
"""Pass inputs on to guvectorized function `_griffiths_drought_factor`.
This function is actually only required as xr.apply_ufunc will not receive
Expand All @@ -327,7 +327,7 @@ def _griffiths_drought_factor_pass(pr, smd, lim):
--------
DO NOT CALL DIRECTLY, use `griffiths_drought_factor` instead.
"""
return _griffiths_drought_factor(pr, smd, lim)
return _griffiths_drought_factor(_pr, _smd, _lim)

pr = convert_units_to(pr, "mm/day", context="hydro")
smd = convert_units_to(smd, "mm/day")
Expand All @@ -339,7 +339,7 @@ def _griffiths_drought_factor_pass(pr, smd, lim):
else:
raise ValueError(f"{limiting_func} is not a valid input for `limiting_func`")

df = xr.apply_ufunc(
df: xr.DataArray = xr.apply_ufunc(
_griffiths_drought_factor_pass,
pr,
smd,
Expand All @@ -349,7 +349,7 @@ def _griffiths_drought_factor_pass(pr, smd, lim):
dask="parallelized",
output_dtypes=[pr.dtype],
)
df.attrs["units"] = ""
df = df.assign_attrs(units="")

# First non-zero entry is at the 19th time point since df is calculated
# from a 20-day rolling window. Make prior points NaNs.
Expand Down
6 changes: 4 additions & 2 deletions xclim/indices/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,7 +1061,9 @@ def _get_zone(_da):
return zones


def detrend(ds, dim="time", deg=1) -> xr.Dataset | xr.DataArray:
def detrend(
ds: xr.DataArray | xr.Dataset, dim="time", deg=1
) -> xr.DataArray | xr.Dataset:
"""Detrend data along a given dimension computing a polynomial trend of a given order.
Parameters
Expand All @@ -1075,7 +1077,7 @@ def detrend(ds, dim="time", deg=1) -> xr.Dataset | xr.DataArray:
Returns
-------
detrended : xr.Dataset or xr.DataArray
xr.Dataset or xr.DataArray
Same as `ds`, but with its trend removed (subtracted).
"""
if isinstance(ds, xr.Dataset):
Expand Down

0 comments on commit 5d0dac6

Please sign in to comment.