Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 112 additions & 15 deletions verl/utils/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""

import dataclasses
import hashlib
import json
import os
from enum import Enum
Expand Down Expand Up @@ -63,22 +64,10 @@ def __init__(self, project_name, experiment_name, default_backend: str | list[st
self.logger = {}

if "tracking" in default_backend or "wandb" in default_backend:
import os

import wandb

settings = None
if config and config["trainer"].get("wandb_proxy", None):
settings = wandb.Settings(https_proxy=config["trainer"]["wandb_proxy"])
entity = os.environ.get("WANDB_ENTITY", None)
wandb.init(project=project_name, name=experiment_name, entity=entity, config=config, settings=settings)
self.logger["wandb"] = wandb
self.logger["wandb"] = _WandbLoggingAdapter(project_name, experiment_name, config)

if "trackio" in default_backend:
import trackio

trackio.init(project=project_name, name=experiment_name, config=config)
self.logger["trackio"] = trackio
self.logger["trackio"] = _TrackioLoggingAdapter(project=project_name, name=experiment_name, config=config)

if "mlflow" in default_backend:
import os
Expand All @@ -91,7 +80,10 @@ def __init__(self, project_name, experiment_name, default_backend: str | list[st
# Some cloud providers like Azure ML or Databricks automatically set MLFLOW_RUN_ID
# If set, attach to the existing run instead of creating a new one
run_id = os.environ.get("MLFLOW_RUN_ID")
if run_id is None:
run_id = _MlflowLoggingAdapter.get_run_id_by_name(project_name, experiment_name)
if run_id:
print(f"[MLflow] Resuming run with ID: {run_id}")
mlflow.start_run(run_id=run_id)
else:
# Project_name is actually experiment_name in MLFlow
Expand Down Expand Up @@ -162,7 +154,7 @@ def __init__(self, project_name, experiment_name, default_backend: str | list[st
def log(self, data, step, backend=None):
for default_backend, logger_instance in self.logger.items():
if backend is None or default_backend in backend:
logger_instance.log(data=data, step=step)
logger_instance.log(data, step=step)

def __del__(self):
if "wandb" in self.logger:
Expand Down Expand Up @@ -308,6 +300,104 @@ def sanitize_key(key):
results = {sanitize_key(k): v for k, v in data.items()}
mlflow.log_metrics(metrics=results, step=step)

@staticmethod
def get_run_id_by_name(experiment_name, run_name):
"""
Search for a run within a specific experiment by its name
and return the run_id.
"""
import mlflow

runs = mlflow.search_runs(
experiment_names=[experiment_name],
filter_string=f"tags.mlflow.runName = '{run_name}'",
max_results=1,
order_by=["attribute.start_time DESC"], # Get the most recent if duplicates exist
)

if runs.empty:
return None

return runs.iloc[0].run_id


class _WandbLoggingAdapter:
METRIC_STEP = "training/global_step"

"""Adapter to log metrics to Weights & Biases (wandb) so that one can log out-of-order."""

def __init__(self, project_name, experiment_name, config):
import os

import wandb

settings = None
if config and config["trainer"].get("wandb_proxy", None):
settings = wandb.Settings(https_proxy=config["trainer"]["wandb_proxy"])
entity = os.environ.get("WANDB_ENTITY", None)
wandb.init(
project=project_name,
name=experiment_name,
entity=entity,
config=config,
settings=settings,
id=self.hash_name(f"{project_name}_{experiment_name}"),
resume="allow",
)
wandb.define_metric(self.METRIC_STEP, hidden=True)
wandb.define_metric("*", step_metric=self.METRIC_STEP)
self.wandb = wandb

def log(self, data, step):
if step is not None:
self.wandb.log(data | {self.METRIC_STEP: step}, commit=True)
else:
self.wandb.log(data)

def finish(self):
self.wandb.finish()

@staticmethod
def hash_name(name: str) -> str:
"""Generate a short hash for a given name."""
return hashlib.md5(name.encode("utf-8")).hexdigest()[:16]


class _TrackioLoggingAdapter:
def __init__(self, project_name, experiment_name, config):
import trackio

trackio.init(project=project_name, name=experiment_name, config=config)
self.trackio = trackio

def log(self, data, step):
self.trackio.log(self._sanitize(data), step=step)

def finish(self):
self.trackio.finish()

@classmethod
def _sanitize(cls, obj: Any) -> Any:
"""Recursively sanitize the object to make it JSON serializable."""
import numpy as np

if isinstance(obj, dict):
return {k: cls._sanitize(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [cls._sanitize(i) for i in obj]
elif isinstance(obj, tuple):
return tuple(cls._sanitize(i) for i in obj)
elif isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.bool_):
return bool(obj)
elif isinstance(obj, np.ndarray):
return [cls._sanitize(i) for i in obj]
else:
return obj


def _compute_mlflow_params_from_objects(params) -> dict[str, Any]:
if params is None:
Expand Down Expand Up @@ -356,6 +446,8 @@ def log(self, loggers, samples, step):
self.log_generations_to_swanlab(samples, step)
if "mlflow" in loggers:
self.log_generations_to_mlflow(samples, step)
if "trackio" in loggers:
self.log_generations_to_trackio(samples, step)

if "clearml" in loggers:
self.log_generations_to_clearml(samples, step)
Expand All @@ -375,6 +467,11 @@ def log_generations_to_wandb(self, samples, step):

self._log_generations_to_wandb(samples, step, wandb)

def log_generations_to_trackio(self, samples, step):
import trackio

self._log_generations_to_wandb(samples, step, trackio)

def _log_generations_to_wandb(self, samples, step, wandb):
"""Log samples to wandb as a table"""

Expand Down