Skip to content

Commit 3bdf8ee

Browse files
authored
docs: Add clarification on checkpoint save location (#21214)
1 parent 252bc21 commit 3bdf8ee

File tree

1 file changed

+36
-3
lines changed

1 file changed

+36
-3
lines changed

docs/source-pytorch/common/checkpointing_basic.rst

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,45 @@ Lightning automatically saves a checkpoint for you in your current working direc
5858
# simply by using the Trainer you get automatic checkpointing
5959
trainer = Trainer()
6060
61-
To change the checkpoint path use the `default_root_dir` argument:
61+
62+
Checkpoint save location
63+
========================
64+
65+
The location where checkpoints are saved depends on whether you have configured a logger:
66+
67+
**Without a logger**, checkpoints are saved to the ``default_root_dir``:
68+
69+
.. code-block:: python
70+
71+
# saves checkpoints to 'some/path/checkpoints/'
72+
trainer = Trainer(default_root_dir="some/path/", logger=False)
73+
74+
**With a logger**, checkpoints are saved to the logger's directory, **not** to ``default_root_dir``:
6275

6376
.. code-block:: python
6477
65-
# saves checkpoints to 'some/path/' at every epoch end
66-
trainer = Trainer(default_root_dir="some/path/")
78+
from lightning.pytorch.loggers import CSVLogger
79+
80+
# checkpoints will be saved to 'logs/my_experiment/version_0/checkpoints/'
81+
# NOT to 'some/path/checkpoints/'
82+
trainer = Trainer(
83+
default_root_dir="some/path/", # This will be ignored for checkpoints!
84+
logger=CSVLogger("logs", "my_experiment")
85+
)
86+
87+
To explicitly control the checkpoint location when using a logger, use the
88+
:class:`~lightning.pytorch.callbacks.ModelCheckpoint` callback:
89+
90+
.. code-block:: python
91+
92+
from lightning.pytorch.callbacks import ModelCheckpoint
93+
94+
# explicitly set checkpoint directory
95+
checkpoint_callback = ModelCheckpoint(dirpath="my/custom/checkpoint/path/")
96+
trainer = Trainer(
97+
logger=CSVLogger("logs", "my_experiment"),
98+
callbacks=[checkpoint_callback]
99+
)
67100
68101
69102
----

0 commit comments

Comments
 (0)