-
-
Notifications
You must be signed in to change notification settings - Fork 292
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1463 from lrzpellegrini/wandb_core_fixes
Various fixes and improvements
- Loading branch information
Showing 9 changed files with 143 additions and 48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,21 +4,21 @@ | |
# See the accompanying LICENSE file for terms. # | ||
# # | ||
# Date: 25-11-2020 # | ||
# Author(s): Diganta Misra, Andrea Cossu # | ||
# Author(s): Diganta Misra, Andrea Cossu, Lorenzo Pellegrini # | ||
# E-mail: contact@continualai.org # | ||
# Website: www.continualai.org # | ||
################################################################################ | ||
""" This module handles all the functionalities related to the logging of | ||
Avalanche experiments using Weights & Biases. """ | ||
|
||
from typing import Union, List, TYPE_CHECKING | ||
import re | ||
from typing import Optional, Union, List, TYPE_CHECKING | ||
from pathlib import Path | ||
import os | ||
import errno | ||
import warnings | ||
|
||
import numpy as np | ||
from numpy import array | ||
import torch | ||
from torch import Tensor | ||
|
||
from PIL.Image import Image | ||
|
@@ -37,6 +37,12 @@ | |
from avalanche.training.templates import SupervisedTemplate | ||
|
||
|
||
# Parses checkpoint metric names of the form | ||
# 'WeightCheckpoint/<phase>_phase/<stream>_stream[/Task<digits>]/Exp<digits>'. | ||
# The Task component is optional, so the 'task_id' group may be None on a match. | ||
CHECKPOINT_METRIC_NAME = re.compile( | ||
r"^WeightCheckpoint\/(?P<phase_name>\S+)_phase\/(?P<stream_name>\S+)_" | ||
r"stream(\/Task(?P<task_id>\d+))?\/Exp(?P<experience_id>\d+)$" | ||
) | ||
|
||
|
||
class WandBLogger(BaseLogger, SupervisedPlugin): | ||
"""Weights and Biases logger. | ||
|
@@ -60,18 +66,21 @@ def __init__( | |
run_name: str = "Test", | ||
log_artifacts: bool = False, | ||
path: Union[str, Path] = "Checkpoints", | ||
uri: str = None, | ||
uri: Optional[str] = None, | ||
sync_tfboard: bool = False, | ||
save_code: bool = True, | ||
config: object = None, | ||
dir: Union[str, Path] = None, | ||
params: dict = None, | ||
config: Optional[object] = None, | ||
dir: Optional[Union[str, Path]] = None, | ||
params: Optional[dict] = None, | ||
): | ||
"""Creates an instance of the `WandBLogger`. | ||
:param project_name: Name of the W&B project. | ||
:param run_name: Name of the W&B run. | ||
:param log_artifacts: Option to log model weights as W&B Artifacts. | ||
Note that, in order for model weights to be logged, the | ||
:class:`WeightCheckpoint` metric must be added to the | ||
evaluation plugin. | ||
:param path: Path to locally save the model checkpoints. | ||
:param uri: URI identifier for external storage buckets (GCS, S3). | ||
:param sync_tfboard: Syncs TensorBoard to the W&B dashboard UI. | ||
|
@@ -102,6 +111,8 @@ def __init__( | |
def import_wandb(self): | ||
try: | ||
import wandb | ||
|
||
assert hasattr(wandb, "__version__") | ||
except ImportError: | ||
raise ImportError('Please run "pip install wandb" to install wandb') | ||
self.wandb = wandb | ||
|
@@ -140,7 +151,7 @@ def after_training_exp( | |
self, | ||
strategy: "SupervisedTemplate", | ||
metric_values: List["MetricValue"], | ||
**kwargs | ||
**kwargs, | ||
): | ||
for val in metric_values: | ||
self.log_metrics([val]) | ||
|
@@ -151,6 +162,11 @@ def after_training_exp( | |
def log_single_metric(self, name, value, x_plot): | ||
self.step = x_plot | ||
|
||
if name.startswith("WeightCheckpoint"): | ||
if self.log_artifacts: | ||
self._log_checkpoint(name, value, x_plot) | ||
return | ||
|
||
if isinstance(value, AlternativeValues): | ||
value = value.best_supported_value( | ||
Image, | ||
|
@@ -192,26 +208,46 @@ def log_single_metric(self, name, value, x_plot): | |
elif isinstance(value, TensorImage): | ||
self.wandb.log({name: self.wandb.Image(array(value))}, step=self.step) | ||
|
||
elif name.startswith("WeightCheckpoint"): | ||
if self.log_artifacts: | ||
cwd = os.getcwd() | ||
ckpt = os.path.join(cwd, self.path) | ||
try: | ||
os.makedirs(ckpt) | ||
except OSError as e: | ||
if e.errno != errno.EEXIST: | ||
raise | ||
suffix = ".pth" | ||
dir_name = os.path.join(ckpt, name + suffix) | ||
artifact_name = os.path.join("Models", name + suffix) | ||
if isinstance(value, Tensor): | ||
torch.save(value, dir_name) | ||
name = os.path.splittext(self.checkpoint) | ||
artifact = self.wandb.Artifact(name, type="model") | ||
artifact.add_file(dir_name, name=artifact_name) | ||
self.wandb.run.log_artifact(artifact) | ||
if self.uri is not None: | ||
artifact.add_reference(self.uri, name=artifact_name) | ||
def _log_checkpoint(self, name, value, x_plot): | ||
# Writes a checkpoint metric value to a local .pth file and logs that file as a W&B artifact of type "model". | ||
# name: metric name, parsed with CHECKPOINT_METRIC_NAME to extract the task and experience ids. | ||
# value: written verbatim to a file opened in binary mode (presumably a bytes-like blob - confirm against the WeightCheckpoint metric). | ||
# x_plot: stored in the artifact metadata under "x_step". | ||
assert self.wandb is not None | ||
|
||
# Example: 'WeightCheckpoint/train_phase/train_stream/Task000/Exp000' | ||
name_match = CHECKPOINT_METRIC_NAME.match(name) | ||
# Names that do not match the expected pattern are skipped with a warning instead of raising. | ||
if name_match is None: | ||
warnings.warn(f"Checkpoint metric has unsupported name {name}.") | ||
return | ||
# phase_name: str = name_match['phase_name'] | ||
# stream_name: str = name_match['stream_name'] | ||
task_id: Optional[int] = ( | ||
int(name_match["task_id"]) if name_match["task_id"] is not None else None | ||
) | ||
experience_id: int = int(name_match["experience_id"]) | ||
assert experience_id >= 0 | ||
|
||
# self.path is resolved against the current working directory; the directory is created if missing. | ||
cwd = Path.cwd() | ||
checkpoint_directory = cwd / self.path | ||
checkpoint_directory.mkdir(parents=True, exist_ok=True) | ||
|
||
# File and artifact names depend only on the experience id (phase, stream and task are not part of them). | ||
checkpoint_name = "Model_{}".format(experience_id) | ||
checkpoint_file_name = checkpoint_name + ".pth" | ||
checkpoint_path = checkpoint_directory / checkpoint_file_name | ||
artifact_name = "Models/" + checkpoint_file_name | ||
|
||
# Write the checkpoint blob | ||
with open(checkpoint_path, "wb") as f: | ||
f.write(value) | ||
|
||
# "task_id" is added to the metadata only when the metric name contains a Task component. | ||
metadata = { | ||
"experience": experience_id, | ||
"x_step": x_plot, | ||
**({"task_id": task_id} if task_id is not None else {}), | ||
} | ||
|
||
artifact = self.wandb.Artifact(checkpoint_name, type="model", metadata=metadata) | ||
artifact.add_file(str(checkpoint_path), name=artifact_name) | ||
self.wandb.run.log_artifact(artifact) | ||
# NOTE(review): the reference is added after log_artifact() and reuses artifact_name, the name already given to add_file() above - confirm W&B accepts both. | ||
if self.uri is not None: | ||
artifact.add_reference(self.uri, name=artifact_name) | ||
|
||
def __getstate__(self): | ||
state = self.__dict__.copy() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters