Skip to content

Commit 0e5b5ad

Browse files
authored
Feat/output chunk length regression model (#761)
* fix multicollinearity in regression ensemble model tests causing exploding coefficients * reset seed to intial value * add output_chunk_length parameter to regression model * add output_chunk_length to fit method of regressionmodel * add check if model support multi output regression natively * remove _shift_matrices test * update the LightGBMModel * update linear regression model * update random forest regression model * update LightGBMModel docstring * use dict for lags in regressionmodel and adjust all models and tests accordingly * reformat regression_ensemble_model using pre-commit * reformat test_regression_models with pre-commit * shorten comment line length * remove unused import to pass flake8 * reformat with black * reformat with black * update docstring of _create_lagged_data * reformat using black * improve error message when unable to build any samples to fit and when input_dim doesn't match * return self at the end of fit() in regressionmodel * remove numpydoc type hints and add n_jobs_multioutput_wrapper parameter to fit() * add comments
1 parent c2d91e0 commit 0e5b5ad

File tree

7 files changed

+725
-602
lines changed

7 files changed

+725
-602
lines changed

darts/models/forecasting/gradient_boosted_model.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
LightGBM Model
33
--------------
44
5-
This is a LightGBM implementation of Gradient Boosted Trees algorightm.
5+
This is a LightGBM implementation of Gradient Boosted Trees algorithm.
66
77
To enable LightGBM support in Darts, follow the detailed install instructions for LightGBM in the README:
88
https://github.com/unit8co/darts/blob/master/README.md
@@ -23,7 +23,8 @@ def __init__(
2323
lags: Union[int, list] = None,
2424
lags_past_covariates: Union[int, List[int]] = None,
2525
lags_future_covariates: Union[Tuple[int, int], List[int]] = None,
26-
**kwargs
26+
output_chunk_length: int = 1,
27+
**kwargs,
2728
):
2829
"""Light Gradient Boosted Model
2930
@@ -41,6 +42,10 @@ def __init__(
4142
given the last `past` lags in the past are used (inclusive, starting from lag -1) along with the first
4243
`future` future lags (starting from 0 - the prediction time - up to `future - 1` included). Otherwise a list
4344
of integers with lags is required.
45+
output_chunk_length
46+
Number of time steps predicted at once by the internal regression model. Does not have to equal the forecast
47+
horizon `n` used in `predict()`. However, setting `output_chunk_length` equal to the forecast horizon may
48+
be useful if the covariates don't extend far enough into the future.
4449
**kwargs
4550
Additional keyword arguments passed to `lightgbm.LGBRegressor`.
4651
"""
@@ -50,13 +55,12 @@ def __init__(
5055
lags=lags,
5156
lags_past_covariates=lags_past_covariates,
5257
lags_future_covariates=lags_future_covariates,
58+
output_chunk_length=output_chunk_length,
5359
model=lgb.LGBMRegressor(**kwargs),
5460
)
5561

5662
def __str__(self):
57-
return "LGBModel(lags={}, lags_past={}, lags_future={})".format(
58-
self.lags, self.lags_past_covariates, self.lags_future_covariates
59-
)
63+
return f"LGBModel(lags={self.lags})"
6064

6165
def fit(
6266
self,
@@ -67,7 +71,7 @@ def fit(
6771
val_past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
6872
val_future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
6973
max_samples_per_ts: Optional[int] = None,
70-
**kwargs
74+
**kwargs,
7175
):
7276
"""
7377
Fits/trains the model using the provided list of features time series and the target time series.
@@ -109,7 +113,7 @@ def fit(
109113
past_covariates=past_covariates,
110114
future_covariates=future_covariates,
111115
max_samples_per_ts=max_samples_per_ts,
112-
**kwargs
116+
**kwargs,
113117
)
114118

115119
return self

darts/models/forecasting/linear_regression_model.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __init__(
1919
lags: Union[int, list] = None,
2020
lags_past_covariates: Union[int, List[int]] = None,
2121
lags_future_covariates: Union[Tuple[int, int], List[int]] = None,
22+
output_chunk_length: int = 1,
2223
**kwargs,
2324
):
2425
"""Linear regression model.
@@ -37,6 +38,10 @@ def __init__(
3738
given the last `past` lags in the past are used (inclusive, starting from lag -1) along with the first
3839
`future` future lags (starting from 0 - the prediction time - up to `future - 1` included). Otherwise a list
3940
of integers with lags is required.
41+
output_chunk_length
42+
Number of time steps predicted at once by the internal regression model. Does not have to equal the forecast
43+
horizon `n` used in `predict()`. However, setting `output_chunk_length` equal to the forecast horizon may
44+
be useful if the covariates don't extend far enough into the future.
4045
**kwargs
4146
Additional keyword arguments passed to `sklearn.linear_model.LinearRegression`.
4247
"""
@@ -45,13 +50,9 @@ def __init__(
4550
lags=lags,
4651
lags_past_covariates=lags_past_covariates,
4752
lags_future_covariates=lags_future_covariates,
53+
output_chunk_length=output_chunk_length,
4854
model=LinearRegression(**kwargs),
4955
)
5056

5157
def __str__(self):
52-
return (
53-
f"LinearRegression(lags={self.lags}, "
54-
f"lags_past_covariates={self.lags_past_covariates}, "
55-
f"lags_historical_covariates={self.lags_historical_covariates}, "
56-
f"lags_future_covariates={self.lags_future_covariates})"
57-
)
58+
return f"LinearRegression(lags={self.lags})"

darts/models/forecasting/random_forest.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __init__(
2828
lags: Union[int, list] = None,
2929
lags_past_covariates: Union[int, List[int]] = None,
3030
lags_future_covariates: Union[Tuple[int, int], List[int]] = None,
31+
output_chunk_length: int = 1,
3132
n_estimators: Optional[int] = 100,
3233
max_depth: Optional[int] = None,
3334
**kwargs,
@@ -48,6 +49,10 @@ def __init__(
4849
given the last `past` lags in the past are used (inclusive, starting from lag -1) along with the first
4950
`future` future lags (starting from 0 - the prediction time - up to `future - 1` included). Otherwise a list
5051
of integers with lags is required.
52+
output_chunk_length
53+
Number of time steps predicted at once by the internal regression model. Does not have to equal the forecast
54+
horizon `n` used in `predict()`. However, setting `output_chunk_length` equal to the forecast horizon may
55+
be useful if the covariates don't extend far enough into the future.
5156
n_estimators : int
5257
The number of trees in the forest.
5358
max_depth : int
@@ -66,14 +71,12 @@ def __init__(
6671
lags=lags,
6772
lags_past_covariates=lags_past_covariates,
6873
lags_future_covariates=lags_future_covariates,
74+
output_chunk_length=output_chunk_length,
6975
model=RandomForestRegressor(**kwargs),
7076
)
7177

7278
def __str__(self):
7379
return (
7480
f"RandomForest(lags={self.lags}, "
75-
f"lags_past_covariates={self.lags_past_covariates}, "
76-
f"lags_historical_covariates={self.lags_historical_covariates}, "
77-
f"lags_future_covariates={self.lags_future_covariates}, "
7881
f"n_estimators={self.n_estimators}, max_depth={self.max_depth})"
7982
)

darts/models/forecasting/regression_ensemble_model.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77
from typing import Optional, List, Union, Sequence, Tuple
88
from darts.timeseries import TimeSeries
9-
from darts.logging import get_logger, raise_if
9+
from darts.logging import get_logger, raise_if, raise_if_not
1010

1111
from darts.models.forecasting.forecasting_model import (
1212
ForecastingModel,
@@ -58,20 +58,12 @@ def __init__(
5858
lags_future_covariates=[0], model=regression_model
5959
)
6060

61-
raise_if(
62-
regression_model.lags is not None
63-
and regression_model.lags_historical_covariates is not None
64-
and regression_model.lags_past_covariates is not None
65-
and regression_model.lags_future_covariates != [0],
66-
(
67-
f"`lags`, `lags_historical_covariates` and `lags_past_covariates` "
68-
f"of regression model must be `None` "
69-
f"and `lags_future_covariates` must be [0]. Given:\n"
70-
f"`lags`: {regression_model.lags}, "
71-
f"`lags_historical_covariates`: {regression_model.lags_historical_covariates}, "
72-
f"`lags_past_covariates`: {regression_model.lags} and "
73-
f"`lags_future_covariates`: {regression_model.lags_future_covariates}."
74-
),
61+
# check lags of the regression model
62+
raise_if_not(
63+
regression_model.lags == {"future": [0]},
64+
f"`lags` and `lags_past_covariates` of regression model must be `None`"
65+
f"and `lags_future_covariates` must be [0]. Given:\n"
66+
f"{regression_model.lags}",
7567
)
7668

7769
self.regression_model = regression_model

0 commit comments

Comments
 (0)