Describe the bug
When a torch-based model is fitted with checkpoints enabled (`save_checkpoints=True`), then persisted with `save_model` and reloaded with `load_model`, the reloaded model fails as soon as the checkpoints are gone. It still holds a reference to the checkpoint it was loaded from and tries to load that checkpoint again during prediction, which raises an error.
To Reproduce
```python
import shutil

import pandas as pd

from darts.dataprocessing.transformers import Scaler
from darts.datasets import AirPassengersDataset
from darts.models import RNNModel

# Read data:
series = AirPassengersDataset().load()

# Create training and validation sets:
train, val = series.split_after(pd.Timestamp("19590101"))

# Normalize the time series (note: we avoid fitting the transformer on the validation set)
transformer = Scaler()
train_transformed = transformer.fit_transform(train)
val_transformed = transformer.transform(val)
series_transformed = transformer.transform(series)

model_name = "Air_RNN"
my_model = RNNModel(
    n_epochs=20,
    model_name=model_name,
    random_state=42,
    input_chunk_length=14,
    training_length=20,
    save_checkpoints=True,
)
my_model.fit(
    train_transformed,
    val_series=val_transformed,
)

best_model = RNNModel.load_from_checkpoint(model_name, best=True)
best_model.save_model("best_model.pth.tar")

# Delete the logs, including the checkpoints (same as `!rm -rf darts_logs` in a notebook)
shutil.rmtree("darts_logs", ignore_errors=True)

# Reload the model
best_model = RNNModel.load_model("best_model.pth.tar")

# This fails because best_model.load_ckpt_path != None, but the checkpoints were deleted
best_model.predict(1, val_transformed)
# Raises FileNotFoundError: Checkpoint at ....ckpt not found. Aborting training.
```
Expected behavior
The loaded model predicts without needing the checkpoint.
System (please complete the following information):
- Python version: 3.7
- darts version: 0.17.1
Additional context
For us it is somewhat critical that the model cannot be loaded outside its training context. However, an easy workaround seems to be to set `best_model.load_ckpt_path = None` after loading and before predicting.
From my limited perspective, there is also no reason the checkpoint should be needed for prediction in the first place. I can prepare a PR that simply skips loading it, if that helps.
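The failure mode and the workaround can be reproduced without darts or torch. The following is a minimal sketch under stated assumptions: `ToyModel` is a hypothetical stand-in, `pickle` plays the role of `save_model`/`load_model`, and a temporary directory plays the role of `darts_logs`. It only mimics the behavior described in this report; it is not how darts implements it.

```python
import os
import pickle
import tempfile


class ToyModel:
    """Hypothetical stand-in for a checkpoint-aware model; not the darts API."""

    def __init__(self):
        self.weights = 2.0           # pretend these are the trained parameters
        self.load_ckpt_path = None   # recorded when loaded from a checkpoint

    def predict(self, x):
        # Mirrors the reported behavior: if a checkpoint path is recorded,
        # predict insists on finding that file, even though the weights
        # are already in memory.
        if self.load_ckpt_path is not None and not os.path.exists(self.load_ckpt_path):
            raise FileNotFoundError(f"Checkpoint at {self.load_ckpt_path} not found.")
        return self.weights * x


with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "best.ckpt")
    open(ckpt, "w").close()        # the checkpoint exists while "training"
    model = ToyModel()
    model.load_ckpt_path = ckpt    # analogous to load_from_checkpoint(..., best=True)
    blob = pickle.dumps(model)     # analogous to save_model(...)
# Leaving the block deletes the directory, like removing darts_logs.

restored = pickle.loads(blob)      # analogous to load_model(...); the stale path survives
try:
    restored.predict(3.0)
except FileNotFoundError as err:
    print("predict failed:", err)

restored.load_ckpt_path = None     # the workaround from the report
print(restored.predict(3.0))       # 6.0
```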
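One way the proposal could look, sketched with a hypothetical `ToyModel` class (its `save_model`/`load_model` are plain `pickle` wrappers, not darts' methods, and the file name is made up): `predict` relies only on the in-memory weights and never consults `load_ckpt_path`, so a model file reloaded after its checkpoint was deleted still predicts.

```python
import os
import pickle
import tempfile


class ToyModel:
    """Hypothetical stand-in for a checkpoint-aware model; not the darts API."""

    def __init__(self, weights=2.0):
        self.weights = weights          # parameters already held in memory
        self.load_ckpt_path = None      # where the weights originally came from

    def predict(self, x):
        # Proposed behavior: use the in-memory weights only and never
        # touch the checkpoint recorded in load_ckpt_path.
        return self.weights * x

    def save_model(self, path):
        with open(path, "wb") as f:
            pickle.dump(self, f)

    @staticmethod
    def load_model(path):
        with open(path, "rb") as f:
            return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    model = ToyModel()
    model.load_ckpt_path = os.path.join(tmp, "deleted.ckpt")  # never created
    path = os.path.join(tmp, "best_model.pkl")
    model.save_model(path)
    loaded = ToyModel.load_model(path)

print(loaded.predict(3.0))  # 6.0, although the checkpoint path does not exist
```

Keeping `load_ckpt_path` as plain metadata is a choice of this sketch; clearing it inside `load_model` would make the persisted model equally self-contained.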