Skip to content

Commit 274a66d

Browse files
authored
Merge branch 'master' into feat/python310
2 parents b8a4cc5 + b530dcd commit 274a66d

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

darts/models/forecasting/forecasting_model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,8 @@ def gridsearch(
680680
If `True`, uses the comparison with the fitted values.
681681
Raises an error if ``fitted_values`` is not an attribute of `model_class`.
682682
metric
683-
A function that takes two TimeSeries instances as inputs and returns a float error value.
683+
A function that takes two TimeSeries instances as inputs (actual and prediction, in this order),
684+
and returns a float error value.
684685
reduction
685686
A reduction function (mapping array to float) describing how to aggregate the errors obtained
686687
on the different validation series when backtesting. By default it'll compute the mean of errors.
@@ -764,7 +765,7 @@ def _evaluate_combination(param_combination) -> float:
764765
fitted_values = TimeSeries.from_times_and_values(
765766
series.time_index, model.fitted_values
766767
)
767-
error = metric(fitted_values, series)
768+
error = metric(series, fitted_values)
768769
elif val_series is None: # expanding window mode
769770
error = model.backtest(
770771
series=series,
@@ -787,7 +788,7 @@ def _evaluate_combination(param_combination) -> float:
787788
future_covariates,
788789
num_samples=1,
789790
)
790-
error = metric(pred, val_series)
791+
error = metric(val_series, pred)
791792

792793
return float(error)
793794

0 commit comments

Comments
 (0)