Closed
Labels
question: Further information is requested
Description
Describe the bug
First, we train a TFTModel for 30 epochs. Next, we try to continue training (transfer learning) by loading that model from its last checkpoint and calling .fit(..., epochs=additional_n_epochs), but the following error occurs:
File "<string>", line 1, in <module>
File ".../python3.9/site-packages/darts/utils/torch.py", line 70, in decorator
return decorated(self, *args, **kwargs)
File ".../python3.9/site-packages/darts/models/forecasting/torch_forecasting_model.py", line 771, in fit
return self.fit_from_dataset(
File ".../python3.9/site-packages/darts/utils/torch.py", line 70, in decorator
return decorated(self, *args, **kwargs)
File ".../python3.9/site-packages/darts/models/forecasting/torch_forecasting_model.py", line 930, in fit_from_dataset
self._train(train_loader, val_loader)
File ".../python3.9/site-packages/darts/models/forecasting/torch_forecasting_model.py", line 952, in _train
self.trainer.fit(
File ".../python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in fit
self._call_and_handle_interrupt(
File ".../python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 723, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File ".../python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 811, in _fit_impl
results = self._run(model, ckpt_path=self.ckpt_path)
File ".../python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1232, in _run
self._checkpoint_connector.restore_training_state()
File ".../python3.9/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 199, in restore_training_state
self.restore_loops()
File ".../python3.9/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 293, in restore_loops
raise MisconfigurationException(
pytorch_lightning.utilities.exceptions.MisconfigurationException: You restored a checkpoint with current_epoch=29, but you have set Trainer(max_epochs=5).
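For readers unfamiliar with this message, the sketch below shows the kind of consistency check it implies. It is a simplified stand-in written for illustration, not PyTorch Lightning's actual code: the function name, the error class, and the exact comparison are assumptions. Only the two numbers (current_epoch=29 and max_epochs=5) come from the traceback above.

```python
class EpochBudgetError(ValueError):
    """Hypothetical error type for this sketch (not a Lightning class)."""


def check_epoch_budget(restored_epoch: int, max_epochs: int) -> None:
    """Raise if the restored epoch counter already exceeds the epoch limit."""
    # Assumed comparison: a checkpoint that is already past the limit
    # leaves no epochs to run, so the configuration is rejected.
    if restored_epoch > max_epochs:
        raise EpochBudgetError(
            f"restored current_epoch={restored_epoch}, "
            f"but max_epochs={max_epochs}"
        )


# Values from the report: checkpoint at current_epoch=29, fit(..., epochs=5).
try:
    check_epoch_budget(restored_epoch=29, max_epochs=5)
except EpochBudgetError as err:
    print(f"rejected: {err}")
```

Under this reading, the value passed as epochs is treated as an absolute limit for the trainer rather than as a number of additional epochs, which would explain why 5 conflicts with a checkpoint saved at epoch 29.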
To Reproduce
additional_n_epochs = 5
my_model = TFTModel.load_from_checkpoint(mymodelname, work_dir=mymodeldir, best=False)
my_model.fit(..., epochs=additional_n_epochs)
Expected behavior
We expect training to resume from the epoch of the last checkpoint and continue until the total number of epochs reaches my_model.n_epochs + additional_n_epochs.
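To make the expected arithmetic concrete, here is a minimal sketch of the total we have in mind. The helper name total_epochs_after_resume is hypothetical and is not part of darts or PyTorch Lightning; the values 30 and 5 are the ones used in this report.

```python
def total_epochs_after_resume(trained_epochs: int, additional_epochs: int) -> int:
    """Return the overall epoch count expected once resumed training ends."""
    if trained_epochs < 0 or additional_epochs < 0:
        raise ValueError("epoch counts must be non-negative")
    return trained_epochs + additional_epochs


# 30 epochs from the first training run, plus 5 requested on resume.
expected_total = total_epochs_after_resume(trained_epochs=30, additional_epochs=5)
print(expected_total)
```

In other words, the expectation is that the epochs argument of .fit() is added to the epochs already stored in the checkpoint, not compared against them.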
System (please complete the following information):
- Python version: 3.9
- darts version: 0.18.0