Model is not completely moved to a different device when using .cuda() or .cpu() #599

Open
1 of 2 tasks
giacomoguiduzzi opened this issue Mar 5, 2025 · 3 comments

@giacomoguiduzzi
Contributor

giacomoguiduzzi commented Mar 5, 2025

Hi Wenjie,

I was debugging an issue with StemGNN allocating too much VRAM in the backward() step, which causes an OOM exception. Since this only happens with some model configurations, I wanted to finish the training on the CPU instead. When I tried, the training could not run because something was left on the CUDA device. While debugging the training function, I discovered that the optimizer state is not moved, because the optimizer is not one of the model's modules.

1. System Info

  • PyPOTS version 0.8.1
  • CUDA 12.4
  • Intel Xeon Gold 6338
  • NVIDIA A40

2. Information

  • The official example scripts
  • My own created scripts

I solved the issue with the following:

imputer.device = torch.device("cpu")
# move model to CPU
imputer.model = imputer.model.cpu()
_move_optimizer_state_to_device(imputer.optimizer.torch_optimizer, 'cpu')
# move data to CPU
train_set = train_set.to(imputer.device)

imputer.fit(
    train_set={"X": train_set},
)

The _move_optimizer_state_to_device is as follows:

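from typing import Union

import torch


def _move_optimizer_state_to_device(optimizer, device: Union[torch.device, int, str]):
    """Moves optimizer state tensors to the specified device (the "step" entry is left as is)."""
    for param_group in optimizer.state_dict()["param_groups"]:
        for param_id in param_group["params"]:
            # state_dict() lists parameters by integer index, while optimizer.state
            # is keyed by the parameter tensors themselves, so this lookup normally
            # misses and the else branch (the whole state mapping) is used.
            if param_id in optimizer.state.keys():
                param_state = optimizer.state[param_id]
            else:
                param_state = optimizer.state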
            # Handle momentum_buffer outside the param loop, as a special case
            if "momentum_buffer" in list(
                {key for param in param_state.values() for key in param.keys()}
            ):
                for param_state_dict in param_state.values():
                    for param_name, param in param_state_dict.items():
                        if param_name == "momentum_buffer":
                            if isinstance(param, tuple):
                                param_state_dict[param_name] = tuple(
                                    (
                                        item.to(device)
                                        if isinstance(item, torch.Tensor)
                                        else item
                                    )
                                    for item in param
                                )
                            elif isinstance(param, torch.Tensor):
                                param_state_dict[param_name] = param.to(device)

            # Iterate over the other keys in param_state
            for param_state_dict in param_state.values():
                for param_name, param in param_state_dict.items():
                    if param_name == "step":
                        continue

                    if param_name != "momentum_buffer" and isinstance(
                        param, torch.Tensor
                    ):
                        param_state_dict[param_name] = param.to(device)

I wrote the function looking at the Adam wrapper I found in PyPOTS and its state_dict, so I'm not sure it works in every case. I am using PyPOTS version 0.8.1.
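
A more compact variant of the helper above could walk `optimizer.state` directly, since that mapping already holds one state dict per parameter. This is only a sketch under assumptions: the name `move_optimizer_state` and the `skip_keys` parameter are made up for illustration, and, like the original function, it leaves `step` untouched by default.

```python
from typing import Tuple, Union

import torch


def move_optimizer_state(
    optimizer: torch.optim.Optimizer,
    device: Union[torch.device, int, str],
    skip_keys: Tuple[str, ...] = ("step",),
) -> int:
    """Send every tensor in optimizer.state to `device`.

    Returns the number of tensors that were handled. Hypothetical helper,
    not part of PyPOTS or PyTorch.
    """
    handled = 0
    # optimizer.state maps each parameter to a dict of buffers, e.g.
    # {"step": ..., "exp_avg": ..., "exp_avg_sq": ...} for Adam or
    # {"momentum_buffer": ...} for SGD with momentum.
    for per_param_state in optimizer.state.values():
        for key, value in per_param_state.items():
            if key in skip_keys:
                continue
            if isinstance(value, torch.Tensor):
                # Reassigning an existing key is safe while iterating items().
                per_param_state[key] = value.to(device)
                handled += 1
    return handled
```

It would be called right after moving the model, for example `move_optimizer_state(imputer.optimizer.torch_optimizer, "cpu")`, using the same attribute as in the snippet above.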

3. Reproduction

  1. Instantiate a StemGNN model on cuda:0;
  2. Move it to CPU;
  3. Run fit().

4. Expected behavior

The model correctly moves all of its parts to the CPU and runs .fit() without raising exceptions.


I wanted to know what you think about it, and to ask whether I could test this on the latest PyPOTS version and create a PR, if you like this approach of course. If so, where should this function be included? I noticed that .cuda() and .cpu() are methods of the PyTorch nn.Module class, so I think it would be cool to override them with a PyPOTS version that calls PyTorch's one and then _move_optimizer_state_to_device. Also, I have tested this only on StemGNN so far, so it should be tested with other optimizers and models before we include it.

Looking forward to your kind response.

Best Regards,
Giacomo Guiduzzi
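
The idea of a `.cpu()` / `.cuda()` that also takes care of the optimizer could be sketched as follows. `DeviceAwareTrainer` is a hypothetical class, not part of PyPOTS. Instead of moving buffers by hand, it relies on `Optimizer.load_state_dict()` casting state tensors to the device of the parameters they belong to, so reloading an optimizer's own state dict after moving the model brings the state along. It takes a list of optimizers, on the assumption that some models may hold more than one.

```python
from typing import Iterable, Union

import torch
from torch import nn


class DeviceAwareTrainer:
    """Hypothetical holder for a network and its optimizer(s)."""

    def __init__(self, model: nn.Module, optimizers: Iterable[torch.optim.Optimizer]):
        self.model = model
        self.optimizers = list(optimizers)
        self.device = next(model.parameters()).device

    def to(self, device: Union[torch.device, str]) -> "DeviceAwareTrainer":
        self.device = torch.device(device)
        # nn.Module.to() swaps parameter data in place, so the optimizers
        # keep pointing at the same Parameter objects.
        self.model.to(self.device)
        for optimizer in self.optimizers:
            # Reloading the optimizer's own state dict makes PyTorch cast the
            # state buffers (exp_avg, momentum_buffer, ...) to the device of
            # the matching parameter.
            optimizer.load_state_dict(optimizer.state_dict())
        return self

    def cpu(self) -> "DeviceAwareTrainer":
        return self.to("cpu")

    def cuda(self, index: int = 0) -> "DeviceAwareTrainer":
        return self.to(f"cuda:{index}")
```

The design choice here is to let PyTorch decide how each state entry is cast, rather than special-casing keys such as `momentum_buffer` or `step`.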

@giacomoguiduzzi giacomoguiduzzi added the bug Something isn't working label Mar 5, 2025
@WenjieDu
Owner

WenjieDu commented Mar 10, 2025

Hey Giacomo, thanks for diving into this! I agree with your analysis, and I'd like to share some details to make things clear. The optimizer is passed into the model as an already initialized object so users can change the concrete implementation as they wish, including hyperparams (e.g. learning rate) and the optimization process (e.g. the algo itself and the lr scheduler). I didn't make the optimizer an argument of the parent classes (BaseNNModel, BaseNNImputer, etc.), because some models need more than one optimizer during training, such as GANs (e.g. USGAN and CRLI in PyPOTS). Hence, I leave it for each concrete model to handle. To sum up, if we'd like to put a func like the _move_optimizer_state_to_device you provided above into to() in BaseModel, we have to consider the case where a model has multiple optimizers.

What do you think? To fix this bug, we may need to do some refactoring to the current framework ;-)

@giacomoguiduzzi
Contributor Author

Hi Wenjie,

I understand, thanks for clarifying. As for having multiple optimizers, I guess it depends on how those optimizers are managed. If we have a list instead of a single torch.optim instance, the solution is trivial (adding another for loop). To better handle different objects and classes, it could be useful to edit _move_optimizer_state_to_device so that it uses a more generic function like dir to traverse the whole object, find tensors and move them.
This might need more work than I thought 🤔
I'll look into this asap, together with the other issue!

Best,
Giacomo
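
One way the attribute traversal mentioned above might look is sketched here. It uses `vars()` rather than `dir()` so that only instance attributes are inspected, and it looks one level deep for optimizers stored directly, inside lists, tuples or dicts, or behind a wrapper. The function name `find_optimizers` is hypothetical, and the `torch_optimizer` attribute is assumed from the snippet in the original post; how each PyPOTS model actually stores its optimizers would need to be checked.

```python
from typing import Any, List

import torch


def find_optimizers(obj: Any) -> List[torch.optim.Optimizer]:
    """Collect torch optimizers stored as attributes of `obj`."""
    found: List[torch.optim.Optimizer] = []

    def visit(value: Any) -> None:
        if isinstance(value, torch.optim.Optimizer):
            found.append(value)
        elif isinstance(value, (list, tuple)):
            for item in value:
                visit(item)
        elif isinstance(value, dict):
            for item in value.values():
                visit(item)
        else:
            # Assumption: a wrapper exposes the real optimizer as
            # `torch_optimizer`, as in the snippet from the original post.
            inner = getattr(value, "torch_optimizer", None)
            if isinstance(inner, torch.optim.Optimizer):
                found.append(inner)

    # vars() returns the instance attributes in definition order.
    for value in vars(obj).values():
        visit(value)
    return found
```

Each optimizer returned could then be handed to whatever routine moves its state, which keeps the discovery step separate from the device logic.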

@WenjieDu
Owner

Thanks, dude @giacomoguiduzzi. Left you a message in #575 ;-)
