Skip to content

Commit 8cc12c8

Browse files
Update torch_forecasting_model.py
Corrected file saving process for checkpoint files (ckpt) to filter out occurrences of the string '.pt' from the previous file path."
1 parent 42ef14a commit 8cc12c8

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

darts/models/forecasting/torch_forecasting_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1579,7 +1579,7 @@ def save(self, path: Optional[str] = None) -> None:
15791579
torch.save(self, f_out)
15801580

15811581
# save the LightningModule checkpoint
1582-
path_ptl_ckpt = path + ".ckpt"
1582+
path_ptl_ckpt = path.replace(".pt", "") + ".ckpt"
15831583
if self.trainer is not None:
15841584
self.trainer.save_checkpoint(path_ptl_ckpt)
15851585
# TODO: keep track of PyTorch Lightning to see if they implement model checkpoint saving

0 commit comments

Comments
 (0)