
[BUG] TFTModel.load_from_checkpoint followed by .fit() returns an error #1090

@criscapdechoy

Description


Describe the bug
First, we train a TFTModel for 30 epochs. We then want to do transfer learning by loading the previous model from its last checkpoint and re-training it. However, calling .fit(..., epochs=additional_n_epochs) raises an error:

File "<string>", line 1, in <module>
File ".../python3.9/site-packages/darts/utils/torch.py", line 70, in decorator
  return decorated(self, *args, **kwargs)
File ".../python3.9/site-packages/darts/models/forecasting/torch_forecasting_model.py", line 771, in fit
  return self.fit_from_dataset(
File ".../python3.9/site-packages/darts/utils/torch.py", line 70, in decorator
  return decorated(self, *args, **kwargs)
File ".../python3.9/site-packages/darts/models/forecasting/torch_forecasting_model.py", line 930, in fit_from_dataset
  self._train(train_loader, val_loader)
File ".../python3.9/site-packages/darts/models/forecasting/torch_forecasting_model.py", line 952, in _train
  self.trainer.fit(
File ".../python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in fit
  self._call_and_handle_interrupt(
File ".../python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 723, in _call_and_handle_interrupt
  return trainer_fn(*args, **kwargs)
File ".../python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 811, in _fit_impl
  results = self._run(model, ckpt_path=self.ckpt_path)
File ".../python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1232, in _run
  self._checkpoint_connector.restore_training_state()
File ".../python3.9/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 199, in restore_training_state
  self.restore_loops()
File ".../python3.9/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 293, in restore_loops
  raise MisconfigurationException(
pytorch_lightning.utilities.exceptions.MisconfigurationException: You restored a checkpoint with current_epoch=29, but you have set Trainer(max_epochs=5).
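The last frame of the traceback is a sanity check that PyTorch Lightning runs while restoring the training loops from a checkpoint. The sketch below is a simplified, hypothetical re-creation of that check, not the actual Lightning source (the exact comparison may differ between versions). It shows why the two numbers in the message conflict: the checkpoint already sits at epoch 29, while the Trainer was apparently given `max_epochs=5` from `fit(..., epochs=5)`. `RestoreMisconfiguration` and `check_resume` are illustrative names only.

```python
class RestoreMisconfiguration(Exception):
    """Hypothetical stand-in for pytorch_lightning's MisconfigurationException."""


def check_resume(current_epoch: int, max_epochs: int) -> None:
    # Simplified sketch of the restore-time check: resuming is refused when the
    # epoch stored in the checkpoint is already past the Trainer's max_epochs.
    # By convention, max_epochs == -1 means "no epoch limit".
    if max_epochs != -1 and current_epoch > max_epochs:
        raise RestoreMisconfiguration(
            f"You restored a checkpoint with current_epoch={current_epoch}, "
            f"but you have set Trainer(max_epochs={max_epochs})."
        )


# Values taken from the traceback above.
try:
    check_resume(current_epoch=29, max_epochs=5)
except RestoreMisconfiguration as err:
    print(err)

# A limit that covers the epochs already trained passes the check.
check_resume(current_epoch=29, max_epochs=35)
```

The point of the sketch is that `max_epochs` is compared against the restored epoch counter, so a small "additional" value cannot pass once the checkpoint is past it.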

To Reproduce

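from darts.models import TFTModel

additional_n_epochs = 5
# Load the model saved after the initial 30-epoch run (last checkpoint, not the best one)
my_model = TFTModel.load_from_checkpoint(mymodelname, work_dir=mymodeldir, best=False)
my_model.fit(..., epochs=additional_n_epochs)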

Expected behavior
Training should resume from the epoch of the last checkpoint and continue until the total number of epochs reaches my_model.n_epochs + additional_n_epochs.
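Put differently, the expected behavior treats `epochs` as an increment, while a Lightning Trainer's `max_epochs` is a total. The small arithmetic sketch below computes the limit a resumed Trainer would need. It assumes that `current_epoch` in the checkpoint is zero-based, so `current_epoch=29` corresponds to the 30 epochs already run. The helper names are hypothetical and not part of darts.

```python
def epochs_completed(checkpoint_epoch: int) -> int:
    # Assumption: the checkpoint's current_epoch is zero-based,
    # so current_epoch=29 means 30 epochs have been completed.
    return checkpoint_epoch + 1


def resume_max_epochs(already_trained: int, additional_epochs: int) -> int:
    # max_epochs is an absolute limit, so it must cover the epochs
    # already in the checkpoint plus the newly requested ones.
    return already_trained + additional_epochs


n_epochs = 30            # epochs of the original training run
additional_n_epochs = 5  # extra epochs requested in fit(...)

trained = epochs_completed(29)
print(trained)                                          # 30
print(resume_max_epochs(trained, additional_n_epochs))  # 35
```

Under these assumptions, the resumed run would need `max_epochs=35` instead of the `max_epochs=5` reported in the error.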

System:

  • Python version: 3.9
  • darts version: 0.18.0

Metadata

  • Assignees: no one assigned
  • Labels: question (Further information is requested)
  • Type, projects, milestone, relationships: none
  • Development: no branches or pull requests
