Skip to content

Commit

Permalink
Merge branch 'feature/wandb_artifact_log_checkpoint' into 'main'
Browse files Browse the repository at this point in the history
Feature/wandb artifact log checkpoint

See merge request ADLR/megatron-lm!2575
  • Loading branch information
ko3n1g committed Jan 22, 2025
2 parents 66e5306 + e2a0e9e commit 27756f4
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
12 changes: 12 additions & 0 deletions megatron/training/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from ..core.dist_checkpointing.serialization import \
get_default_save_sharded_strategy
from .one_logger_utils import on_save_checkpoint_start, on_save_checkpoint_success
from . import wandb_utils

# [ModelOpt]: Import
try:
Expand Down Expand Up @@ -514,6 +515,17 @@ def onelogger_finalize_fn():
else:
onelogger_finalize_fn()

# Additional callback for wandb (last rank)
if not torch.distributed.is_initialized() \
or is_last_rank():
def wandb_finalize_fn():
wandb_utils.on_save_checkpoint_success(checkpoint_name, get_checkpoint_tracker_filename(save_dir), save_dir, iteration)
if args.async_save:
assert async_save_request is not None
async_save_request.add_finalize_fn(wandb_finalize_fn)
else:
wandb_finalize_fn()

if args.async_save:
schedule_async_save(async_save_request)
print_rank_0(' scheduled an async checkpoint save at iteration {:7d} to {}' \
Expand Down
25 changes: 25 additions & 0 deletions megatron/training/wandb_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from pathlib import Path

from megatron.training.global_vars import get_wandb_writer


def on_save_checkpoint_success(checkpoint_path: str, tracker_filename: str, save_dir: str, iteration: int) -> None:
    """Log a freshly persisted checkpoint as a W&B model artifact.

    Intended to run after the checkpoint has been fully written to disk.
    No-ops when no W&B writer is configured.

    Args:
        checkpoint_path (str): path of the saved checkpoint
        tracker_filename (str): path of the tracker filename for the checkpoint iteration
        save_dir (str): path of the root save folder for all checkpoints
        iteration (int): iteration of the checkpoint
    """
    writer = get_wandb_writer()
    if not writer:
        return

    # The checkpoint itself is added by reference (no upload, no checksum pass
    # over a potentially huge directory); only the small tracker file is
    # attached directly. The artifact is aliased by the checkpoint dir name.
    artifact = writer.Artifact(
        Path(save_dir).stem,
        type="model",
        metadata={"iteration": iteration},
    )
    artifact.add_reference(f"file://{checkpoint_path}", checksum=False)
    artifact.add_file(tracker_filename)
    writer.run.log_artifact(artifact, aliases=[Path(checkpoint_path).stem])

0 comments on commit 27756f4

Please sign in to comment.