Skip to content

Commit

Permalink
fix: default value None for dict
Browse files Browse the repository at this point in the history
  • Loading branch information
madtoinou committed Nov 3, 2023
1 parent a797b13 commit a3817e7
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
5 changes: 4 additions & 1 deletion darts/models/forecasting/regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,7 +1101,7 @@ def _optimized_historical_forecasts(
last_points_only: bool = True,
verbose: bool = False,
show_warnings: bool = True,
predict_kwargs: Dict[str, Any] = {},
predict_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[
TimeSeries, List[TimeSeries], Sequence[TimeSeries], Sequence[List[TimeSeries]]
]:
Expand All @@ -1122,6 +1122,9 @@ def _optimized_historical_forecasts(
allow_autoregression=False,
)

if predict_kwargs is None:
predict_kwargs = dict()

# TODO: move the loop here instead of duplicated code in each sub-routine?
if last_points_only:
return _optimized_historical_forecasts_last_points_only(
Expand Down
2 changes: 1 addition & 1 deletion darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2035,7 +2035,7 @@ def _optimized_historical_forecasts(
last_points_only: bool = True,
verbose: bool = False,
show_warnings: bool = True,
predict_kwargs: Dict[str, Any] = {},
predict_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[
TimeSeries, List[TimeSeries], Sequence[TimeSeries], Sequence[List[TimeSeries]]
]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _optimized_historical_forecasts(
last_points_only: bool = True,
show_warnings: bool = True,
verbose: bool = False,
predict_kwargs: Dict[str, Any] = {},
predict_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[
TimeSeries, List[TimeSeries], Sequence[TimeSeries], Sequence[List[TimeSeries]]
]:
Expand All @@ -41,6 +41,9 @@ def _optimized_historical_forecasts(
Rely on _check_optimizable_historical_forecasts() to check that the assumptions are verified.
"""
if predict_kwargs is None:
predict_kwargs = dict()

bounds = []
for idx, series_ in enumerate(series):
past_covariates_ = past_covariates[idx] if past_covariates is not None else None
Expand Down

0 comments on commit a3817e7

Please sign in to comment.