Pickling error when trying to save checkpoints with custom checkpointIO #11955

Open
jdnurme opened this issue Jan 24, 2025 · 5 comments
Labels: bug Something isn't working

jdnurme commented Jan 24, 2025

Describe the bug

When providing a custom checkpoint_io to my strategy during NeMo training, the torch.save call fails with a pickling error.

i.pretrain/0 [default5]:[rank5]:   File "/opt/NeMo/nemo/lightning/pytorch/strategies/megatron_strategy.py", line 664, in save_checkpoint
i.pretrain/0 [default5]:[rank5]:     self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
i.pretrain/0 [default5]:[rank5]:   File "/root/.nemo_run/experiments/jd_pretraining/jd_pretraining_1737756914/__main__.py", line 32, in save_checkpoint
i.pretrain/0 [default5]:[rank5]:     torch.save(checkpoint, p)
i.pretrain/0 [default5]:[rank5]:   File "/usr/local/lib/python3.10/dist-packages/torch/serialization.py", line 632, in save
i.pretrain/0 [default5]:[rank5]:     _save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record)
i.pretrain/0 [default5]:[rank5]:   File "/usr/local/lib/python3.10/dist-packages/torch/serialization.py", line 844, in _save
i.pretrain/0 [default5]:[rank5]:     pickler.dump(obj)
i.pretrain/0 [default5]:[rank5]: AttributeError: Can't pickle local object 'apply_swiglu_sharded_factory.<locals>.sh_ten_build_fn'

I'm using a minimal custom Lightning CheckpointIO implementation wrapped with the @run.autoconvert decorator. The goal is eventually to augment this implementation to save checkpoints to a remote datastore; for now it simply saves to a separate disk location using torch.save().

Training itself succeeds, but the torch.save call fails with the pickling error above when the checkpoint is written.
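
For context on the error itself: pickle stores functions by their qualified name, so functions defined inside other functions ("local objects", as in the traceback) cannot be serialized. A minimal standalone illustration, unrelated to NeMo:

import pickle

def outer():
    # Functions defined inside another function are "local objects":
    # pickle records functions by qualified name, and
    # 'outer.<locals>.inner' cannot be looked up at load time.
    def inner():
        return 1
    return inner

try:
    pickle.dumps(outer())
except (AttributeError, pickle.PicklingError) as err:
    # On Python 3.10 (as in the traceback above) this prints:
    # Can't pickle local object 'outer.<locals>.inner'
    print(err)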

Steps/Code to reproduce bug

Run the following code with python filename.py:

import nemo_run as run
from nemo.collections import llm
import torch

from lightning.pytorch.plugins.io import CheckpointIO
from typing import Any, Dict, Optional, Union
from pathlib import Path
from nemo import lightning as nl
from megatron.core.distributed import DistributedDataParallelConfig

class CustomIO(CheckpointIO):
    """A Custom Checkpoint Manager"""

    def __init__(
        self,
    ):
        super().__init__()

    def save_checkpoint(
        self,
        checkpoint: Dict[str, Any],
        path: Union[str, Path],
        storage_options: Optional[Any] = None,
    ) -> None:
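        # Save each checkpoint under a fresh UUID-named path instead of the
        # trainer-provided one.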
        import uuid
        u = uuid.uuid4()
        p = "/checkpoints/uuid/" + str(u)
        #torch.save(checkpoint, path)
        torch.save(checkpoint, p)

    def load_checkpoint(
        self,
        path: Union[str, Path],
        map_location: Optional[Any] = None,
    ) -> Dict[str, Any]:
        pass

    def remove_checkpoint(
        self,
        path: Union[str, Path],
    ) -> None:
        pass 

    def teardown(self) -> None:
        pass


def configure_recipe(nodes: int = 1, gpus_per_node: int = 8):
    recipe = llm.llama3_8b.pretrain_recipe(
        dir="/checkpoints/triall_1/", # Path to store checkpoints
        name="llama3_pretraining",
        num_nodes=nodes,
        num_gpus_per_node=gpus_per_node,
    )
    #import pdb; pdb.set_trace()
    recipe.trainer.max_steps = 1
    custom_ckpt = get_ckpt_io()
    strategy = run.Config(
      nl.MegatronStrategy,
      tensor_model_parallel_size=2,
      pipeline_model_parallel_size=1,
      pipeline_dtype=None,
      virtual_pipeline_model_parallel_size=None,
      context_parallel_size=1,
      sequence_parallel=False,
      gradient_as_bucket_view=True,
      checkpoint_io=custom_ckpt,
      ckpt_async_save=False,
      ckpt_parallel_load=True,
      ddp=run.Config(
          DistributedDataParallelConfig,
          check_for_nan_in_grad=True,
          grad_reduce_in_fp32=True,
          overlap_grad_reduce=True,
          overlap_param_gather=True,
          average_in_collective=True,
      )
    )
    recipe.trainer.strategy = strategy

    recipe.trainer.val_check_interval = 100
    return recipe

@run.autoconvert
def get_ckpt_io():
    return CustomIO()

def local_executor_torchrun(nodes: int = 1, devices: int = 8) -> run.LocalExecutor:
    # Env vars for jobs are configured here
    env_vars = {
        "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
        "NCCL_NVLS_ENABLE": "0",
        "NVTE_DP_AMAX_REDUCE_INTERVAL": "0",
        "NVTE_ASYNC_AMAX_REDUCTION": "1",
    }

    executor = run.LocalExecutor(ntasks_per_node=devices, launcher="torchrun", env_vars=env_vars)

    return executor

def run_pretraining():
    recipe = configure_recipe()
    executor = local_executor_torchrun(nodes=recipe.trainer.num_nodes, devices=recipe.trainer.devices)

    run.run(recipe, executor=executor, name="jd_pretraining")

# This condition is necessary for the script to be compatible with Python's multiprocessing module.
if __name__ == "__main__":
    run_pretraining()

Expected behavior

I would expect this code to run to completion, with checkpoints saved to the custom location specified in the CheckpointIO implementation.

Environment overview (please complete the following information)

  • Environment location: NeMo docker container running on GCS GPU instance.
  • Method of NeMo install: Docker container with NeMo pre-installed
  • Docker command used: `sudo docker pull nvcr.io/nvidia/nemo:dev`

Environment details

  • The NVIDIA Docker image was used, so no further environment details are needed.

Additional context

If an alternate methodology is better suited for custom checkpointing with NeMo 2, please advise. The hope is that existing Lightning CheckpointIO integrations would be able to plug into NeMo directly.

jdnurme added the bug (Something isn't working) label on Jan 24, 2025
jdnurme changed the title from "Pickling error when trying to save checkpoints to google bucket" to "Pickling error when trying to save checkpoints with custom checkpointIO" on Jan 25, 2025

jdnurme (Author) commented Feb 6, 2025

Any help with this issue would be greatly appreciated. It should be very easy to recreate with the given script.

ananthsub (Collaborator) commented Feb 11, 2025

@jdnurme With the Megatron strategy, distributed checkpointing via megatron-core is used, as described here: https://docs.nvidia.com/nemo-framework/user-guide/24.09/nemotoolkit/checkpoints/dist_ckpt.html

The pickling errors arise because these functions cannot be serialized by torch.save: https://github.com/NVIDIA/Megatron-LM/blob/f2f81012239d7b98a76580e147da336cf3c94e93/megatron/core/transformer/mlp.py#L165-L253

> The goal is to eventually augment this implementation to save checkpoints to a remote datastore

Since megatron-core distributed checkpointing is built on top of PyTorch's distributed checkpoint, is your remote storage compatible with the StorageWriter and StorageReader abstractions laid out here? https://pytorch.org/docs/stable/distributed.checkpoint.html

That may be an easier integration point to discuss than handling all of the ShardedTensor logic introduced for distributed checkpointing.
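
A rough sketch of how that abstraction plugs in, using generic torch.distributed.checkpoint APIs (assuming torch >= 2.2; the RemoteWriter class and the /mnt/remote path are hypothetical placeholders, not NeMo or Megatron code):

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter

class RemoteWriter(FileSystemWriter):
    """Hypothetical StorageWriter: reuse FileSystemWriter's planning and
    serialization, changing only where the bytes land (e.g. a mounted
    bucket or a custom backend)."""
    def __init__(self, remote_path: str):
        super().__init__(remote_path)

state_dict = {"model": torch.nn.Linear(4, 4).state_dict()}

# Save: shards are written through the StorageWriter implementation.
dcp.save(state_dict, storage_writer=RemoteWriter("/mnt/remote/ckpt-0"))

# Load: the matching StorageReader reads the shards back in place.
dcp.load(state_dict, storage_reader=FileSystemReader("/mnt/remote/ckpt-0"))

The point is that a remote backend would implement or subclass StorageWriter/StorageReader rather than intercepting torch.save of the whole checkpoint dict.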

jdnurme (Author) commented Feb 11, 2025

@ananthsub Thank you for the reply. Yes, the remote checkpoint code we intend to integrate with does support the StorageReader/StorageWriter interface. Can you elaborate on how we might pass that checkpoint mechanism into a NeMo pretrain recipe like the one above?

ananthsub (Collaborator) commented:

@jdnurme We are working on integrating NVIDIA's multi-storage client into the checkpointing flow to support object stores like GCS: https://github.com/NVIDIA/multi-storage-client. Would this meet your needs once it's available?

jdnurme (Author) commented Feb 13, 2025

Yes, an integration like that would be super helpful. @ananthsub do you have a sense of when this feature will be ready?
