Skip to content

Commit 6a5c946

Browse files
authored
fix: rich progress bar error when resume training (#21000)
1 parent 5d89996 commit 6a5c946

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

src/lightning/pytorch/callbacks/progress/rich_progress.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,11 @@ def _add_task(self, total_batches: Union[int, float], description: str, visible:
447447
visible=visible,
448448
)
449449

450+
def _initialize_train_progress_bar_id(self) -> None:
451+
total_batches = self.total_train_batches
452+
train_description = self._get_train_description(self.trainer.current_epoch)
453+
self.train_progress_bar_id = self._add_task(total_batches, train_description)
454+
450455
def _update(self, progress_bar_id: Optional["TaskID"], current: int, visible: bool = True) -> None:
451456
if self.progress is not None and self.is_enabled:
452457
assert progress_bar_id is not None
@@ -531,6 +536,9 @@ def on_train_batch_end(
531536
batch: Any,
532537
batch_idx: int,
533538
) -> None:
539+
if not self.is_disabled and self.train_progress_bar_id is None:
540+
# can happen when resuming from a mid-epoch restart
541+
self._initialize_train_progress_bar_id()
534542
self._update(self.train_progress_bar_id, batch_idx + 1)
535543
self._update_metrics(trainer, pl_module)
536544
self.refresh()

0 commit comments

Comments
 (0)