
[BUG] Torch model with checkpoints is not saved independently of its checkpoints #820

@tharwan

Describe the bug
When a torch-based model is fitted with checkpoints enabled and then persisted with save_model and load_model, the reloaded model fails as soon as it is used. The saved model still contains a reference to the checkpoints and tries to load the checkpoint again during prediction, which raises an error once the checkpoint files are gone.

To Reproduce

from darts.dataprocessing.transformers import Scaler
from darts.models import RNNModel
from darts.datasets import AirPassengersDataset
import pandas as pd

# Read data:
series = AirPassengersDataset().load()

# Create training and validation sets:
train, val = series.split_after(pd.Timestamp("19590101"))

# Normalize the time series (note: we avoid fitting the transformer on the validation set)
transformer = Scaler()
train_transformed = transformer.fit_transform(train)
val_transformed = transformer.transform(val)
series_transformed = transformer.transform(series)

model_name = "Air_RNN"
my_model = RNNModel(
    n_epochs=20,
    model_name=model_name,
    random_state=42,
    input_chunk_length=14,
    training_length=20,
    save_checkpoints=True,
)

my_model.fit(
    train_transformed,
    val_series=val_transformed,
)

best_model = RNNModel.load_from_checkpoint(model_name, best=True)
best_model.save_model("best_model.pth.tar")

# delete the checkpoint logs (in a notebook: !rm -rf darts_logs)
import shutil
shutil.rmtree("darts_logs")

# reload model
best_model = RNNModel.load_model("best_model.pth.tar")

# this fails: best_model.load_ckpt_path is still set, but the checkpoints have been deleted
best_model.predict(1, val_transformed)
# Raises FileNotFoundError: Checkpoint at ....ckpt not found. Aborting training.

Expected behavior
The loaded model predicts without needing the checkpoint.

System

  • Python version: 3.7
  • darts version: 0.17.1

Additional context
For us it is somewhat critical that the model cannot be loaded and used outside its training context. However, an easy workaround seems to be to set best_model.load_ckpt_path = None after loading the model.
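
To make the failure mode and the workaround concrete, here is a minimal, self-contained sketch. ToyModel, its JSON file format and the file names are hypothetical; the class only mimics the behavior described in this issue (the saved state keeps load_ckpt_path, and predict reloads weights from that path) and is not darts' actual implementation.

```python
import json
import os
import tempfile


class ToyModel:
    """Hypothetical stand-in for a torch-based model; not the darts API."""

    def __init__(self):
        self.weights = None         # in-memory "trained" parameters
        self.load_ckpt_path = None  # reference to an on-disk checkpoint

    def fit(self, ckpt_path):
        self.weights = [0.5, 1.5]
        # Write a checkpoint and remember its path, loosely like save_checkpoints=True.
        with open(ckpt_path, "w") as f:
            json.dump(self.weights, f)
        self.load_ckpt_path = ckpt_path

    def save_model(self, path):
        # The full object state is saved, including load_ckpt_path.
        with open(path, "w") as f:
            json.dump(vars(self), f)

    @classmethod
    def load_model(cls, path):
        model = cls()
        with open(path) as f:
            model.__dict__.update(json.load(f))
        return model

    def predict(self, x):
        if self.load_ckpt_path is not None:
            # Mirrors the reported behavior: weights are reloaded from the checkpoint.
            if not os.path.exists(self.load_ckpt_path):
                raise FileNotFoundError(f"Checkpoint at {self.load_ckpt_path} not found.")
            with open(self.load_ckpt_path) as f:
                self.weights = json.load(f)
        return self.weights[0] * x + self.weights[1]


workdir = tempfile.mkdtemp()
ckpt_path = os.path.join(workdir, "last.ckpt")
model_path = os.path.join(workdir, "best_model.json")

model = ToyModel()
model.fit(ckpt_path)
model.save_model(model_path)
os.remove(ckpt_path)  # plays the role of deleting darts_logs

reloaded = ToyModel.load_model(model_path)
try:
    reloaded.predict(2.0)
    failed_before_workaround = False
except FileNotFoundError:
    failed_before_workaround = True

# Workaround from the issue: drop the stale checkpoint reference after loading.
reloaded.load_ckpt_path = None
prediction = reloaded.predict(2.0)
```

After the workaround, the reloaded toy model predicts from the weights it already holds in memory, without touching the deleted checkpoint.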

From my limited perspective, there is also no reason why the checkpoint should be needed for prediction in the first place. I can prepare a PR that simply skips loading it, if that helps.
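
One possible shape of that change, again as a hypothetical toy rather than a patch to darts: predict falls back to the checkpoint only when the model has no weights in memory, so a model restored with load_model would no longer depend on the checkpoint directory. ToyModel, its constructor arguments and the JSON checkpoint format are assumptions made for illustration.

```python
import json
import os
import tempfile


class ToyModel:
    """Hypothetical stand-in for a torch-based model; not the darts API."""

    def __init__(self, weights=None, load_ckpt_path=None):
        self.weights = weights                # in-memory parameters, if any
        self.load_ckpt_path = load_ckpt_path  # optional checkpoint reference

    def predict(self, x):
        # Proposed behavior: consult the checkpoint only if no weights are in memory.
        if self.weights is None and self.load_ckpt_path is not None:
            with open(self.load_ckpt_path) as f:
                self.weights = json.load(f)
        if self.weights is None:
            raise RuntimeError("Model has no weights: fit it or load a checkpoint first.")
        return self.weights[0] * x + self.weights[1]


# A loaded model that already has weights ignores its stale checkpoint path.
stale = ToyModel(weights=[0.5, 1.5], load_ckpt_path="/nonexistent/last.ckpt")
stale_prediction = stale.predict(2.0)

# A model created from a checkpoint alone still reads it on first use.
workdir = tempfile.mkdtemp()
ckpt_path = os.path.join(workdir, "best.ckpt")
with open(ckpt_path, "w") as f:
    json.dump([1.0, 0.0], f)
fresh = ToyModel(load_ckpt_path=ckpt_path)
fresh_prediction = fresh.predict(3.0)
```

The design choice here is to keep checkpoint loading as a fallback instead of removing it, so models built directly from a checkpoint keep working while saved models become self-contained.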

Labels: bug (Something isn't working), triage (Issue waiting for triaging)
