Skip to content

Commit

Permalink
Merge pull request #1463 from lrzpellegrini/wandb_core_fixes
Browse files Browse the repository at this point in the history
Various fixes and improvements
  • Loading branch information
AntonioCarta authored Jul 19, 2023
2 parents 435b40d + abde4c2 commit c1f34d1
Show file tree
Hide file tree
Showing 9 changed files with 143 additions and 48 deletions.
14 changes: 14 additions & 0 deletions avalanche/benchmarks/classic/core50.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,20 @@ def CORe50(
eval_transform=eval_transform,
)

if scenario == "nc":
n_classes_per_exp = []
classes_order = []
for exp in benchmark_obj.train_stream:
exp_dataset = exp.dataset
unique_targets = list(
sorted(set(int(x) for x in exp_dataset.targets)) # type: ignore
)
n_classes_per_exp.append(len(unique_targets))
classes_order.extend(unique_targets)
setattr(benchmark_obj, "n_classes_per_exp", n_classes_per_exp)
setattr(benchmark_obj, "classes_order", classes_order)
setattr(benchmark_obj, "n_classes", 50 if object_lvl else 10)

return benchmark_obj


Expand Down
23 changes: 16 additions & 7 deletions avalanche/evaluation/metrics/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
################################################################################

import copy
from typing import TYPE_CHECKING
import io
from typing import TYPE_CHECKING, Optional

from torch import Tensor
import torch

from avalanche.evaluation import PluginMetric
from avalanche.evaluation.metric_results import MetricValue, MetricResult
Expand Down Expand Up @@ -46,9 +48,9 @@ def __init__(self):
retrieved using the `result` method.
"""
super().__init__()
self.weights = None
self.weights: Optional[bytes] = None

def update(self, weights) -> Tensor:
def update(self, weights: bytes):
"""
Update the weight checkpoint at the current experience.
Expand All @@ -57,7 +59,7 @@ def update(self, weights) -> Tensor:
"""
self.weights = weights

def result(self) -> Tensor:
def result(self) -> Optional[bytes]:
"""
Retrieves the weight checkpoint at the current experience.
Expand All @@ -75,6 +77,9 @@ def reset(self) -> None:

def _package_result(self, strategy) -> "MetricResult":
weights = self.result()
if weights is None:
return None

metric_name = get_metric_name(
self, strategy, add_experience=True, add_task=False
)
Expand All @@ -83,9 +88,13 @@ def _package_result(self, strategy) -> "MetricResult":
]

def after_training_exp(self, strategy: "SupervisedTemplate") -> "MetricResult":
model_params = copy.deepcopy(strategy.model.parameters())
self.update(model_params)
return None
buff = io.BytesIO()
model_params = copy.deepcopy(strategy.model).to("cpu")
torch.save(model_params, buff)
buff.seek(0)
self.update(buff.read())

return self._package_result(strategy)

def __str__(self):
return "WeightCheckpoint"
Expand Down
6 changes: 4 additions & 2 deletions avalanche/logging/text_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
# Website: avalanche.continualai.org #
################################################################################
import datetime
import os.path
import sys
import warnings
from typing import List, TYPE_CHECKING, Tuple, Type, Optional, TextIO
Expand All @@ -24,7 +23,10 @@
if TYPE_CHECKING:
from avalanche.training.templates import SupervisedTemplate

UNSUPPORTED_TYPES: Tuple[Type] = (TensorImage,)
UNSUPPORTED_TYPES: Tuple[Type, ...] = (
TensorImage,
bytes,
)


class TextLogger(BaseLogger, SupervisedPlugin):
Expand Down
94 changes: 65 additions & 29 deletions avalanche/logging/wandb_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,21 @@
# See the accompanying LICENSE file for terms. #
# #
# Date: 25-11-2020 #
# Author(s): Diganta Misra, Andrea Cossu #
# Author(s): Diganta Misra, Andrea Cossu, Lorenzo Pellegrini #
# E-mail: [email protected] #
# Website: www.continualai.org #
################################################################################
""" This module handles all the functionalities related to the logging of
Avalanche experiments using Weights & Biases. """

from typing import Union, List, TYPE_CHECKING
import re
from typing import Optional, Union, List, TYPE_CHECKING
from pathlib import Path
import os
import errno
import warnings

import numpy as np
from numpy import array
import torch
from torch import Tensor

from PIL.Image import Image
Expand All @@ -37,6 +37,12 @@
from avalanche.training.templates import SupervisedTemplate


# Parses checkpoint metric names such as
# 'WeightCheckpoint/train_phase/train_stream/Task000/Exp000'.
# The '/Task<digits>' segment is optional. Named groups: phase_name,
# stream_name, task_id (None when the task segment is absent) and
# experience_id.
CHECKPOINT_METRIC_NAME = re.compile(
r"^WeightCheckpoint\/(?P<phase_name>\S+)_phase\/(?P<stream_name>\S+)_"
r"stream(\/Task(?P<task_id>\d+))?\/Exp(?P<experience_id>\d+)$"
)


class WandBLogger(BaseLogger, SupervisedPlugin):
"""Weights and Biases logger.
Expand All @@ -60,18 +66,21 @@ def __init__(
run_name: str = "Test",
log_artifacts: bool = False,
path: Union[str, Path] = "Checkpoints",
uri: str = None,
uri: Optional[str] = None,
sync_tfboard: bool = False,
save_code: bool = True,
config: object = None,
dir: Union[str, Path] = None,
params: dict = None,
config: Optional[object] = None,
dir: Optional[Union[str, Path]] = None,
params: Optional[dict] = None,
):
"""Creates an instance of the `WandBLogger`.
:param project_name: Name of the W&B project.
:param run_name: Name of the W&B run.
:param log_artifacts: Option to log model weights as W&B Artifacts.
Note that, in order for model weights to be logged, the
:class:`WeightCheckpoint` metric must be added to the
evaluation plugin.
:param path: Path to locally save the model checkpoints.
:param uri: URI identifier for external storage buckets (GCS, S3).
:param sync_tfboard: Syncs TensorBoard to the W&B dashboard UI.
Expand Down Expand Up @@ -102,6 +111,8 @@ def __init__(
def import_wandb(self):
try:
import wandb

assert hasattr(wandb, "__version__")
except ImportError:
raise ImportError('Please run "pip install wandb" to install wandb')
self.wandb = wandb
Expand Down Expand Up @@ -140,7 +151,7 @@ def after_training_exp(
self,
strategy: "SupervisedTemplate",
metric_values: List["MetricValue"],
**kwargs
**kwargs,
):
for val in metric_values:
self.log_metrics([val])
Expand All @@ -151,6 +162,11 @@ def after_training_exp(
def log_single_metric(self, name, value, x_plot):
self.step = x_plot

if name.startswith("WeightCheckpoint"):
if self.log_artifacts:
self._log_checkpoint(name, value, x_plot)
return

if isinstance(value, AlternativeValues):
value = value.best_supported_value(
Image,
Expand Down Expand Up @@ -192,26 +208,46 @@ def log_single_metric(self, name, value, x_plot):
elif isinstance(value, TensorImage):
self.wandb.log({name: self.wandb.Image(array(value))}, step=self.step)

elif name.startswith("WeightCheckpoint"):
if self.log_artifacts:
cwd = os.getcwd()
ckpt = os.path.join(cwd, self.path)
try:
os.makedirs(ckpt)
except OSError as e:
if e.errno != errno.EEXIST:
raise
suffix = ".pth"
dir_name = os.path.join(ckpt, name + suffix)
artifact_name = os.path.join("Models", name + suffix)
if isinstance(value, Tensor):
torch.save(value, dir_name)
name = os.path.splittext(self.checkpoint)
artifact = self.wandb.Artifact(name, type="model")
artifact.add_file(dir_name, name=artifact_name)
self.wandb.run.log_artifact(artifact)
if self.uri is not None:
artifact.add_reference(self.uri, name=artifact_name)
def _log_checkpoint(self, name, value, x_plot):
    """Write a checkpoint blob to disk and log it as a W&B model artifact.

    :param name: Name of the checkpoint metric, for instance
        'WeightCheckpoint/train_phase/train_stream/Task000/Exp000'.
        Names that do not match `CHECKPOINT_METRIC_NAME` are skipped
        with a warning.
    :param value: The checkpoint content, written verbatim to
        `<cwd>/<self.path>/Model_<experience_id>.pth`.
    :param x_plot: The x-axis step, stored in the artifact metadata
        under the 'x_step' key.
    :return: None.
    """
    assert self.wandb is not None

    # Example: 'WeightCheckpoint/train_phase/train_stream/Task000/Exp000'
    name_match = CHECKPOINT_METRIC_NAME.match(name)
    if name_match is None:
        warnings.warn(f"Checkpoint metric has unsupported name {name}.")
        return

    # The task segment is optional in the metric name. The experience id
    # is matched by `\d+`, hence it is always a non-negative integer.
    raw_task_id = name_match["task_id"]
    task_id: Optional[int] = int(raw_task_id) if raw_task_id is not None else None
    experience_id: int = int(name_match["experience_id"])

    checkpoint_directory = Path.cwd() / self.path
    checkpoint_directory.mkdir(parents=True, exist_ok=True)

    checkpoint_name = f"Model_{experience_id}"
    checkpoint_file_name = checkpoint_name + ".pth"
    checkpoint_path = checkpoint_directory / checkpoint_file_name
    artifact_name = "Models/" + checkpoint_file_name

    # Write the checkpoint blob
    checkpoint_path.write_bytes(value)

    metadata = {"experience": experience_id, "x_step": x_plot}
    if task_id is not None:
        metadata["task_id"] = task_id

    artifact = self.wandb.Artifact(checkpoint_name, type="model", metadata=metadata)
    artifact.add_file(str(checkpoint_path), name=artifact_name)

    # The external reference must be attached *before* logging: logging
    # finalizes the artifact, after which no more entries can be added.
    # NOTE(review): the reference reuses the same entry name as the local
    # file; confirm that W&B accepts both entries under one name.
    if self.uri is not None:
        artifact.add_reference(self.uri, name=artifact_name)

    self.wandb.run.log_artifact(artifact)

def __getstate__(self):
state = self.__dict__.copy()
Expand Down
10 changes: 8 additions & 2 deletions avalanche/training/plugins/ewc.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def after_training_exp(self, strategy, **kwargs):
strategy.experience.dataset,
strategy.device,
strategy.train_mb_size,
num_workers=kwargs.get("num_workers", 0),
)
self.update_importances(importances, exp_counter)
self.saved_params[exp_counter] = copy_params_dict(strategy.model)
Expand All @@ -129,7 +130,7 @@ def after_training_exp(self, strategy, **kwargs):
del self.saved_params[exp_counter - 1]

def compute_importances(
self, model, criterion, optimizer, dataset, device, batch_size
self, model, criterion, optimizer, dataset, device, batch_size, num_workers=0
) -> Dict[str, ParamData]:
"""
Compute EWC importance matrix for each parameter
Expand All @@ -153,7 +154,12 @@ def compute_importances(
# list of list
importances = zerolike_params_dict(model)
collate_fn = dataset.collate_fn if hasattr(dataset, "collate_fn") else None
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
collate_fn=collate_fn,
num_workers=num_workers,
)
for i, batch in enumerate(dataloader):
# get only input, target and task_id from the batch
x, y, task_labels = batch[0], batch[1], batch[-1]
Expand Down
4 changes: 2 additions & 2 deletions examples/multihead.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ def main(args):

# train and test loop
for train_task in train_stream:
strategy.train(train_task)
strategy.eval(test_stream)
strategy.train(train_task, num_workers=4)
strategy.eval(test_stream, num_workers=4)


if __name__ == "__main__":
Expand Down
23 changes: 21 additions & 2 deletions examples/wandb_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from avalanche.benchmarks import nc_benchmark
from avalanche.benchmarks.datasets.dataset_utils import default_dataset_location
from avalanche.evaluation.metrics.checkpoint import WeightCheckpoint
from avalanche.logging import InteractiveLogger, WandBLogger
from avalanche.training.plugins import EvaluationPlugin
from avalanche.evaluation.metrics import (
Expand Down Expand Up @@ -83,7 +84,11 @@ def main(args):

interactive_logger = InteractiveLogger()
wandb_logger = WandBLogger(
project_name=args.project, run_name=args.run, config=vars(args)
project_name=args.project,
run_name=args.run,
log_artifacts=args.artifacts,
path=args.path if args.path else None,
config=vars(args),
)

eval_plugin = EvaluationPlugin(
Expand Down Expand Up @@ -120,6 +125,7 @@ def main(args):
),
disk_usage_metrics(minibatch=True, epoch=True, experience=True, stream=True),
MAC_metrics(minibatch=True, epoch=True, experience=True),
WeightCheckpoint(),
loggers=[interactive_logger, wandb_logger],
)

Expand Down Expand Up @@ -157,9 +163,22 @@ def main(args):
default=0,
help="Select zero-indexed cuda device. -1 to use CPU.",
)
parser.add_argument("--run", type=str, help="Provide a run name for WandB")
parser.add_argument(
"--project", type=str, help="Define the name of the WandB project"
)
parser.add_argument("--run", type=str, help="Provide a run name for WandB")
parser.add_argument(
"--artifacts",
default=False,
action="store_true",
help="Log Model Checkpoints as W&B Artifacts",
)
parser.add_argument(
"--path",
type=str,
default="Checkpoint",
help="Local path to save the model checkpoints",
)

args = parser.parse_args()
main(args)
5 changes: 5 additions & 0 deletions tests/test_core50.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ def test_core50_nc_benchmark(self):
classes_in_test = benchmark_instance.classes_in_experience["test"][0]
self.assertSetEqual(set(range(50)), set(classes_in_test))

# Regression tests for issue #774
self.assertSequenceEqual([10] + ([5] * 8), benchmark_instance.n_classes_per_exp)
self.assertSetEqual(set(range(50)), set(benchmark_instance.classes_order))
self.assertEqual(50, len(benchmark_instance.classes_order))


if __name__ == "__main__":
unittest.main()
12 changes: 8 additions & 4 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sys
import os
import copy
import tempfile

import unittest

Expand Down Expand Up @@ -650,10 +651,13 @@ def test_ncm_save_load(self):
),
}
)
torch.save(classifier.state_dict(), "ncm.pt")
del classifier
classifier = NCMClassifier()
check = torch.load("ncm.pt")

with tempfile.TemporaryFile() as tmpfile:
torch.save(classifier.state_dict(), tmpfile)
del classifier
classifier = NCMClassifier()
tmpfile.seek(0)
check = torch.load(tmpfile)
classifier.load_state_dict(check)
assert classifier.class_means.shape == (3, 5)
assert (classifier.class_means[0] == 0).all()
Expand Down

0 comments on commit c1f34d1

Please sign in to comment.