
ModelCheckpoint callback not working/saving unless save_on_train_epoch_end is set to True, which considerably slows down training #20200

@snknitin

Bug description

Problem

  • The checkpoint folder doesn't get created at all without save_on_train_epoch_end: True.
  • If I set it to False, training runs very fast but no checkpoint is saved, even with save_last: True.
  • If I set it to True and also set save_last: True, I get a checkpoints folder with two checkpoints, the best and the last.
  • If I enable save_on_train_epoch_end and match every_n_epochs with the Trainer's check_val_every_n_epoch, it is still just as slow and gives me the 99th, 499th, or 999th checkpoint (depending on whether I choose 100, 500, or 1000) rather than the best one according to the logged metric I monitor, which I think is expected.
  • I originally did not have an on_train_epoch_end hook, so I created one and moved the logged metric I monitor into it, but the behaviour is the same: no checkpoint is created unless save_on_train_epoch_end is True.
  • Creating a validation_step is of no use because without a val_dataloader it is skipped entirely, and if I copy the train_dataloader and expose it as a val_dataloader it just throws errors in the sanity check and doesn't work (a rough sketch of the kind of validation loop I mean is shown after this list).
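
For reference, this is roughly what I mean by a validation loop (a minimal sketch only; the buffer re-sampling and the "val/loss" key are my own illustration, not something from the docs). These two methods would be added to the DQNLightning module in the repro script below:

    # Sketch: give the Trainer a validation loop so that callbacks which fire at
    # validation end have something to hook into. Both the buffer re-sampling and
    # the "val/loss" key are assumptions for illustration.
    def val_dataloader(self) -> DataLoader:
        dataset = RLDataset(self.buffer, self.hparams.episode_length)
        return DataLoader(dataset=dataset, batch_size=self.hparams.batch_size)

    def validation_step(self, batch, batch_idx):
        loss = self.dqn_mse_loss(batch)
        # Logged so a ModelCheckpoint monitor could point at "val/loss"
        self.log("val/loss", loss, on_step=False, on_epoch=True)
        return loss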

What version are you seeing the problem on?

v2.1, v2.2, v2.3, v2.4, master

How to reproduce the bug

import time

import torch
from torch import nn, optim
import lightning as pl
from torch.utils.data import DataLoader, random_split, Dataset
from torch.nn import functional as F
from torchmetrics import MeanMetric, SumMetric
from torchvision.datasets import MNIST
from torchvision import transforms
from torchvision import datasets
import os
import collections
from collections import OrderedDict
import numpy as np
import gym
import argparse


def process_state(state):
    if isinstance(state, tuple) and len(state) == 2 and isinstance(state[1], dict):
        return state[0]
    else:
        return state

class DQN(nn.Module):
    """
    Simple MLP network
    Args:
        obs_size: observation/state size of the environment
        n_actions: number of discrete actions available in the environment
        hidden_size: size of hidden layers
    """

    def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128):
        super(DQN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions)
        )

    def forward(self, x):
        return self.net(x.float())


# Named tuple for storing experience steps gathered in training
Experience = collections.namedtuple(
    'Experience', field_names=['state', 'action', 'reward',
                               'done', 'new_state'])

class ReplayBuffer:
    """
    Replay Buffer for storing past experiences allowing the agent to learn from them
    Args:
        capacity: size of the buffer
    """

    def __init__(self, capacity: int) -> None:
        self.buffer = collections.deque(maxlen=capacity)

    def __len__(self) -> int:
        return len(self.buffer)

    def append(self, experience: Experience) -> None:
        """
        Add experience to the buffer
        Args:
            experience: tuple (state, action, reward, done, new_state)
        """
        self.buffer.append(experience)

    def sample(self, batch_size: int):  # -> Tuple:
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        experiences = [self.buffer[idx] for idx in indices]

        states = np.array([exp.state for exp in experiences])
        actions = np.array([exp.action for exp in experiences])
        rewards = np.array([exp.reward for exp in experiences], dtype=np.float32)
        dones = np.array([exp.done for exp in experiences], dtype=np.bool_)
        new_states = np.array([exp.new_state for exp in experiences])

        return states, actions, rewards, dones, new_states


class RLDataset(torch.utils.data.IterableDataset):
    """
    Iterable Dataset containing the ReplayBuffer
    which will be updated with new experiences during training
    Args:
        buffer: replay buffer
        sample_size: number of experiences to sample at a time
    """

    def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:
        self.buffer = buffer
        self.sample_size = sample_size

    def __iter__(self): # -> Tuple:
        states, actions, rewards, dones, new_states = self.buffer.sample(self.sample_size)
        for i in range(len(dones)):
            yield states[i], actions[i], rewards[i], dones[i], new_states[i]


class Agent:
    """
    Base Agent class handling the interaction with the environment
    Args:
        env: training environment
        replay_buffer: replay buffer storing experiences
    """

    def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None:
        self.env = env
        self.replay_buffer = replay_buffer
        # reset() already sets self.state to the processed initial observation
        self.reset()

    def process_state(self, state):
        if isinstance(state, tuple) and len(state) == 2 and isinstance(state[1], dict):
            return state[0]
        else:
            return state

    def reset(self) -> None:
        """ Resents the environment and updates the state"""
        self.state = self.process_state(self.env.reset())

    def get_action(self, net: nn.Module, epsilon: float, device: str) -> int:
        """
        Using the given network, decide what action to carry out
        using an epsilon-greedy policy
        Args:
            net: DQN network
            epsilon: value to determine likelihood of taking a random action
            device: current device
        Returns:
            action
        """
        if np.random.random() < epsilon:
            action = self.env.action_space.sample()
        else:
            state = self.process_state(self.state)
            state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
            state = state.to(device)

            q_values = net(state)
            _, action = torch.max(q_values, dim=1)
            action = int(action.item())

        return action

    @torch.no_grad()
    def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = 'cpu'): # -> Tuple[float, bool]:
        """
        Carries out a single interaction step between the agent and the environment
        Args:
            net: DQN network
            epsilon: value to determine likelihood of taking a random action
            device: current device
        Returns:
            reward, done
        """

        action = self.get_action(net, epsilon, device)

        new_state, reward, done, _, _ = self.env.step(action)

        # Process states
        state = np.array(self.process_state(self.state))
        new_state = np.array(self.process_state(new_state))

        exp = Experience(state, action, reward, done, new_state)

        self.replay_buffer.append(exp)

        self.state = new_state
        if done:
            self.reset()
        return reward, done


class DQNLightning(pl.LightningModule):
    """ Basic DQN Model """

    def __init__(self, hparams: argparse.Namespace) -> None:
        super().__init__()

        if isinstance(hparams, dict):
            # If hparams is already a dict (loaded from checkpoint)
            self.save_hyperparameters(hparams)
        else:
            # If hparams is a Namespace (initial creation)
            self.save_hyperparameters(vars(hparams))
        # self.hparams = hparams

        self.env = gym.make(self.hparams.env)
        obs_size = self.env.observation_space.shape[0]
        n_actions = self.env.action_space.n

        self.net = DQN(obs_size, n_actions)
        self.target_net = DQN(obs_size, n_actions)

        self.buffer = ReplayBuffer(self.hparams.replay_size)
        self.agent = Agent(self.env, self.buffer)

        self.total_reward = 0

        self.populate(self.hparams.warm_start_steps)

        self.train_loss = MeanMetric()
        self.avg_episodic_reward = MeanMetric()

        self.episode_reward = SumMetric()
        self.episode_length = SumMetric()  # env steps in the current episode
        self.episode_count = 0             # incremented in training_step when an episode ends
        self.episode_rewards = [0]
        self.cumulative_step_reward = 0
        self.mavg_reward = 0

    def populate(self, steps: int = 1000) -> None:
        """
        Carries out several random steps through the environment to initially fill
        up the replay buffer with experiences
        Args:
            steps: number of random steps to populate the buffer with
        """
        for i in range(steps):
            self.agent.play_step(self.net, epsilon=1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Passes in a state x through the network and gets the q_values of each action as an output
        Args:
            x: environment state
        Returns:
            q values
        """
        output = self.net(x)
        return output

    def dqn_mse_loss(self, batch) -> torch.Tensor: # : Tuple[torch.Tensor, torch.Tensor]
        """
        Calculates the mse loss using a mini batch from the replay buffer
        Args:
            batch: current mini batch of replay data
        Returns:
            loss
        """
        states, actions, rewards, dones, next_states = batch
        # Convert actions to torch.int64
        actions = actions.long()
        state_action_values = self.net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)

        with torch.no_grad():
            next_state_values = self.target_net(next_states).max(1)[0]
            next_state_values[dones] = 0.0
            next_state_values = next_state_values.detach()

        expected_state_action_values = next_state_values * self.hparams.gamma + rewards

        return nn.MSELoss()(state_action_values, expected_state_action_values)

    def on_train_start(self):
        # To ensure every run is proper from start to finish and not mid-way from populate buffer
        self.env.reset()
        # Reset metrics for next episode

        self.train_loss.reset()
        self.avg_episodic_reward.reset()

        # Reset metrics for next episode
        self.episode_reward.reset()
        self.episode_length.reset()

        self.episode_rewards = [0]
        self.cumulative_step_reward = 0
        self.mavg_reward = 0

    def training_step(self, batch, nb_batch): # : Tuple[torch.Tensor, torch.Tensor]
        """
        Carries out a single step through the environment to update the replay buffer.
        Then calculates loss based on the minibatch received
        Args:
            batch: current mini batch of replay data
            nb_batch: batch number
        Returns:
            Training loss and log metrics
        """
        # Linearly decay epsilon from eps_start towards eps_end over eps_last_frame steps
        epsilon = max(self.hparams.eps_end,
                      self.hparams.eps_start - (self.global_step + 1) / self.hparams.eps_last_frame)

        # step through environment with agent
        reward, done = self.agent.play_step(self.net, epsilon, self.device)

        # calculates training loss
        loss = self.dqn_mse_loss(batch)
        self.train_loss(loss)
        # Update both reward metrics
        self.cumulative_step_reward += float(reward)  # env reward is a plain float
        self.episode_reward(reward)
        self.avg_episodic_reward(reward)
        self.episode_length(1)  # count env steps in the current episode

        # Log moving averages
        window_size = min(5, len(self.episode_rewards))
        self.mavg_reward = sum(self.episode_rewards[-window_size:]) / window_size
        cumulative_episode_reward = sum(self.episode_rewards)

        self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("train/cumulative_step_reward", self.cumulative_step_reward, on_step=False, on_epoch=True)
        self.log("train/cumulative_episodic_reward", cumulative_episode_reward, on_step=False, on_epoch=True)
        # self.log("train/Moving_avg_5_ep_reward", self.mavg_reward, on_step=False, on_epoch=True)


        # if self.trainer.use_dp or self.trainer.use_ddp2:
        #     loss = loss.unsqueeze(0)


        # Soft update of target network
        if self.global_step % self.hparams.sync_rate == 0:
            self.target_net.load_state_dict(self.net.state_dict())

        reward = torch.tensor(reward, dtype=torch.float32).to(self.device).item()
        steps = torch.tensor(self.global_step).to(self.device)

        # Log environment step metrics
        self.log("env_step/reward", reward, on_step=False, on_epoch=True, prog_bar=False)
        self.log("env_step/loss", loss.item(), on_step=False, on_epoch=True, prog_bar=True)

        log = {'loss': loss,
               'reward': reward,
               'steps': steps}

        if done:
            self.episode_count += 1
            episode_reward = self.episode_reward.compute()
            episode_avg_reward = self.avg_episodic_reward.compute()
            self.episode_rewards.append(episode_reward.item())

            # Log episode-level metrics only when an episode is done
            self.log("episode/total_reward", episode_reward.item(), on_step=False, on_epoch=True, prog_bar=True)
            self.log("episode/avg_reward", episode_avg_reward.item(), on_step=False, on_epoch=True, prog_bar=True)

            self.log("episode/length", self.episode_length.compute(), on_step=False, on_epoch=True)
            self.log("episode/avg_loss", self.train_loss.compute(), on_step=False, on_epoch=True)


            # Log custom episode count
            self.log("episode/count", self.episode_count, on_step=False, on_epoch=True)


            # Reset metrics for next episode
            self.episode_reward.reset()
            self.episode_length.reset()

            # Reset the environment for the next episode
            self.env.reset()


        return {'loss': loss, 'reward': reward}

    def on_train_epoch_end(self) -> None:
        "Lightning hook that is called when a training epoch ends."
        self.log("train/Moving_avg_5_ep_reward", self.mavg_reward, on_step=False, on_epoch=True)


    def configure_optimizers(self):  # -> List[Optimizer]:
        """ Initialize Adam optimizer"""
        optimizer = optim.Adam(self.net.parameters(), lr=self.hparams.lr)
        return [optimizer]

    def train_dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences"""
        dataset = RLDataset(self.buffer, self.hparams.episode_length)
        dataloader = DataLoader(dataset=dataset,
                                batch_size=self.hparams.batch_size,
                                )
        return dataloader

    def get_device(self, batch) -> str:
        """Retrieve device currently being used by minibatch"""
        return batch[0].device.index if self.on_gpu else 'cpu'


def main(hparams) -> None:
    pl.seed_everything(1234)
    model = DQNLightning(hparams)

    experiment_dir = os.path.join(os.getcwd(), "../../src/checkpoints")
    callbacks = []
    goldstar_metric = "train/Moving_avg_5_ep_reward"
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
                    dirpath = experiment_dir,
                    filename= "epoch_{epoch:03d}",
                    monitor=goldstar_metric,
                    mode= "max",
                    save_last= True,
                    auto_insert_metric_name= False,
                    every_n_epochs= 900, # match with val_check_interval
                    save_on_train_epoch_end= False
                    )

    summary_callback = pl.callbacks.ModelSummary(max_depth=2)

    callbacks = [summary_callback, checkpoint_callback]

    # early_stopping_callback = pl.callbacks.EarlyStopping(monitor="train/total_reward", mode="max", patience=3)
    # callbacks.append(early_stopping_callback)


    trainer = pl.Trainer( accelerator="cpu",
        # gpus=1,
        # distributed_backend='dp',
        max_epochs=1000,
        #early_stop_callback=False,
        val_check_interval=900,
        callbacks=callbacks
    )

    trainer.fit(model)
    trainer.save_checkpoint("example.ckpt")

    # Load the model
    # new_model = DQNLightning.load_from_checkpoint(checkpoint_path="example.ckpt", hparams=hparams)

    # env = gym.make(new_model.hparams.env)
    # obs = process_state(env.reset())
    # done = False
    # for i in range(30000):
    #     q_values = new_model(torch.tensor([obs], dtype=torch.float32))
    #     _, action = torch.max(q_values, dim=1)
    #     action = int(action.item())
    #
    #     obs, rew, done, info,_ = env.step(action)
    #     env.render()
    #     # print(rew)
    #     if done:
    #         obs = process_state(env.reset())
    # env.close()


if __name__ == '__main__':
    torch.manual_seed(0)
    np.random.seed(0)

    start = time.time()

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument("--env", type=str, default="CartPole-v1", help="gym environment tag")
    parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
    parser.add_argument("--sync_rate", type=int, default=10,
                        help="how many frames do we update the target network")
    parser.add_argument("--replay_size", type=int, default=1000,
                        help="capacity of the replay buffer")
    parser.add_argument("--warm_start_size", type=int, default=1000,
                        help="how many samples do we use to fill our buffer at the start of training")
    parser.add_argument("--eps_last_frame", type=int, default=1000,
                        help="what frame should epsilon stop decaying")
    parser.add_argument("--eps_start", type=float, default=1.0, help="starting value of epsilon")
    parser.add_argument("--eps_end", type=float, default=0.01, help="final value of epsilon")
    parser.add_argument("--episode_length", type=int, default=200, help="max length of an episode")
    parser.add_argument("--max_episode_reward", type=int, default=200,
                        help="max episode reward in the environment")
    parser.add_argument("--warm_start_steps", type=int, default=1000,
                        help="max episode reward in the environment")

    args = parser.parse_args()

    main(args)

    end = time.time()
    print("Time taken :",end-start)

Error messages and logs

No error messages as such. The checkpoint folder is simply not created if the flag is not set to True, which doesn't seem like the correct behavior.

Environment

name: RL-inv-env

channels:
  - pytorch
  - conda-forge
  - defaults

dependencies:
  - python=3.10
  - pytorch=2.*
  - torchvision=0.*
  - lightning=2.*
  - torchmetrics=0.*
  - hydra-core=1.*
  - rich=13.*
  - pre-commit=3.*
  - pytest=7.*
  - numpy=1.26.4
  - wandb=0.17.2
  - tensorboard
  - ipykernel
  - pip>=23
  - pip:
      - hydra-optuna-sweeper
      - hydra-colorlog
      - rootutils
      - gym
      - stable-baselines3[extra]
      - torchrl
      - ray[rllib]

More info

I'm using the pytorch-lightning + hydra template for my custom RL project. For the sake of reproducing the issue, I have simplified it and put all the base code into one runnable file. The model runs very fast, but if I enable checkpointing it becomes roughly 4x slower. There must be something wrong with my settings or my approach; after racking my brain for a whole day and running 30 experiments with different trial-and-error combinations, I am lost.

The loggers are set up appropriately to show me env_step and episode-level metrics based on the on_step and on_epoch parameters. Everything works perfectly except checkpointing; early stopping never triggers either.

NEEDED: To figure out how to make this work with just the save_last checkpoint flag, i.e. create the model checkpoint directory and save the final checkpoint of the run, instead of checking at every step, delaying training, and then picking the best checkpoint. Since this is RL, I don't expect the monitored reward to show any degradation after convergence. (The sketch below shows the kind of callback I mean.)
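
To make the goal concrete, this is roughly the callback I would like to be able to use (a sketch only, reusing the experiment_dir variable and the pl alias from the repro script above; the argument names are from the ModelCheckpoint docs, and whether this actually fires without a validation loop is exactly what this issue is about):

    # Sketch: save only last.ckpt, with no metric monitoring at all
    last_only_ckpt = pl.callbacks.ModelCheckpoint(
        dirpath=experiment_dir,  # expect the directory to be created on the first save
        monitor=None,            # no "best checkpoint" tracking
        save_top_k=0,            # keep no ranked checkpoints
        save_last=True,          # just keep last.ckpt
    )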

Context

There is a caveat here.

  1. I only have a training_step and no other hooks except on_train_start. When I start my Trainer with max_epochs=1000, my trainer/global_step also goes up to 1000 and training_step is called 1000 times, while on_train_start is called just once. So each training_step is effectively one epoch.
  2. Each training_step = one env step plus one batch (of size 200) sampled from the buffer for the loss calculation. If my episode terminates after 100 steps, then 1000 epochs = 1000 env steps = 1000 buffer updates = 10 episodes = 1000 batches sampled for training.

Basically I do not see how to get it to trigger and save a checkpoint without the save_on_train_epoch_end flag, nor how to make it fast and not check every epoch (which in my case is every time step), because I need to run this for 100,000 epochs/steps. Even saving a checkpoint every 10,000 steps, without monitoring and picking the best model by the "moving avg reward across 10 episodes", would be fine, because it eventually converges and there is not much difference. (A step-based sketch of what I mean follows below.)
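
For the fallback case, this is the kind of step-based schedule I have in mind (again just a sketch, using the every_n_train_steps argument from the ModelCheckpoint docs rather than anything epoch- or validation-based):

    # Sketch: checkpoint every 10k global steps, independent of epochs and validation
    step_ckpt = pl.callbacks.ModelCheckpoint(
        dirpath=experiment_dir,
        filename="step_{step:07d}",
        every_n_train_steps=10_000,  # save on a step schedule
        monitor=None,
        save_top_k=-1,               # keep every step checkpoint
        save_last=True,
    )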

BEST CASE: if I can get it to create a checkpoint directory and just save the last epoch, without having to use the epoch-end flag and slowing down the whole training and experimentation.

This is the callback config I use.

defaults:
  - model_checkpoint
  - early_stopping
  - model_summary
  - rich_progress_bar
  - _self_

model_checkpoint:
  dirpath: ${paths.output_dir}/checkpoints
  filename: "epoch_{epoch:03d}"
  monitor: "train/Moving_avg_5_ep_reward"
  mode: "max"
  save_last: True
  every_n_epochs: 100
  auto_insert_metric_name: False
  save_on_train_epoch_end: True

early_stopping:
  monitor: "total_reward"
  min_delta: 10
  patience: 3
  mode: "max"

model_summary:
  max_depth: -1
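
One thing I noticed while putting this together: the repro script logs the episode reward as "episode/total_reward", while the early_stopping config above monitors "total_reward". If the full Hydra project logs the same keys as the repro, that monitor would never be found, which may be part of why early stopping never triggers. In plain Lightning terms, what I mean is something like this (my assumption, not a confirmed fix):

    # Sketch (assumption): point the monitor at a key the code actually logs
    early_stopping = pl.callbacks.EarlyStopping(
        monitor="episode/total_reward",  # must match a self.log() key exactly
        mode="max",
        min_delta=10,
        patience=3,
        check_on_train_epoch_end=True,   # run the check at train-epoch end, since there is no val loop
    )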
