File tree Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Original file line number Diff line number Diff line change @@ -680,7 +680,8 @@ def gridsearch(
680
680
If `True`, uses the comparison with the fitted values.
681
681
Raises an error if ``fitted_values`` is not an attribute of `model_class`.
682
682
metric
683
- A function that takes two TimeSeries instances as inputs and returns a float error value.
683
+ A function that takes two TimeSeries instances as inputs (actual and prediction, in this order),
684
+ and returns a float error value.
684
685
reduction
685
686
A reduction function (mapping array to float) describing how to aggregate the errors obtained
686
687
on the different validation series when backtesting. By default it'll compute the mean of errors.
@@ -764,7 +765,7 @@ def _evaluate_combination(param_combination) -> float:
764
765
fitted_values = TimeSeries .from_times_and_values (
765
766
series .time_index , model .fitted_values
766
767
)
767
- error = metric (fitted_values , series )
768
+ error = metric (series , fitted_values )
768
769
elif val_series is None : # expanding window mode
769
770
error = model .backtest (
770
771
series = series ,
@@ -787,7 +788,7 @@ def _evaluate_combination(param_combination) -> float:
787
788
future_covariates ,
788
789
num_samples = 1 ,
789
790
)
790
- error = metric (pred , val_series )
791
+ error = metric (val_series , pred )
791
792
792
793
return float (error )
793
794
You can’t perform that action at this time.
0 commit comments