Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions detectron2/utils/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Optional
import torch
from fvcore.common.history_buffer import HistoryBuffer
import wandb

from detectron2.utils.file_io import PathManager

Expand Down Expand Up @@ -191,6 +192,62 @@ def close(self):
if "_writer" in self.__dict__:
self._writer.close()

class WandbWriter(EventWriter):
def __init__(self, project_name, run_name=None, window_size=20, **kwargs):
"""
Args:
project_name (str): The name of the W&B project.
run_name (str): The name of the W&B run.
window_size (int): The window size for smoothing metrics.
kwargs: Additional arguments for wandb.init().
"""
self._window_size = window_size
self._last_write = -1
wandb.init(project=project_name, name=run_name, **kwargs)

def write(self):
storage = get_event_storage()
new_last_write = self._last_write
metrics_dict = storage.latest_with_smoothing_hint(self._window_size).items()
wandb_metrics = {}
new_last_write = self._last_write
for k, (v, iter) in metrics_dict:
if iter > self._last_write:
wandb_metrics[k] = v
new_last_write = max(new_last_write, iter)
self._last_write = new_last_write

if len(storage._vis_data) >= 1:
# Create a list to store all images for this step
images_dict = {}

for img_name, img, step_num in storage._vis_data:
# Transpose from C,H,W to H,W,C
img = img.transpose(1, 2, 0)
# Add image to dictionary
images_dict[img_name] = wandb.Image(img)

# Log both metrics and all images for this step
log_dict = {
**wandb_metrics, # Unpack all metrics
**images_dict # Unpack all images
}
wandb.log(log_dict, step=iter)

# Storage stores all image data and rely on this writer to clear them.
# As a result it assumes only one writer will use its image data.
# An alternative design is to let storage store limited recent
# data (e.g. only the most recent image) that all writers can access.
# In that case a writer may not see all image data if its period is long.
storage.clear_images()
else:
wandb.log(wandb_metrics, step=new_last_write)

self._last_write = new_last_write

def close(self):
wandb.finish()


class CommonMetricPrinter(EventWriter):
"""
Expand Down