Skip to content

Commit

Permalink
added last subexp checking to at_task_boundaries util
Browse files Browse the repository at this point in the history
  • Loading branch information
AlbinSou committed Oct 7, 2023
1 parent a3ed06a commit 7a92ffc
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions avalanche/training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
"""
from collections import defaultdict
from typing import Dict, NamedTuple, List, Optional, Tuple, Callable, Union
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Module, Linear
from torch.utils.data import Dataset, DataLoader
from torch.nn import Linear, Module
from torch.utils.data import DataLoader, Dataset

from avalanche.models.batch_renorm import BatchRenorm2D
from avalanche.benchmarks import OnlineCLExperience
from avalanche.models.batch_renorm import BatchRenorm2D


def at_task_boundary(training_experience) -> bool:
Expand All @@ -45,7 +45,10 @@ def at_task_boundary(training_experience) -> bool:

if isinstance(training_experience, OnlineCLExperience):
if training_experience.access_task_boundaries:
if training_experience.is_first_subexp:
if (
training_experience.is_first_subexp
or training_experience.is_last_subexp
):
return True
else:
return True
Expand Down Expand Up @@ -222,7 +225,7 @@ def replace_bn_with_brn(
):
for attr_str in dir(m):
target_attr = getattr(m, attr_str)
if type(target_attr) == torch.nn.BatchNorm2d:
if isinstance(target_attr, torch.nn.BatchNorm2d):
# print('replaced: ', name, attr_str)
setattr(
m,
Expand Down Expand Up @@ -253,7 +256,7 @@ def change_brn_pars(
):
for attr_str in dir(m):
target_attr = getattr(m, attr_str)
if type(target_attr) == BatchRenorm2D:
if isinstance(target_attr, BatchRenorm2D):
target_attr.momentum = torch.tensor((momentum), requires_grad=False)
target_attr.r_max = torch.tensor(r_max, requires_grad=False)
target_attr.d_max = torch.tensor(d_max, requires_grad=False)
Expand Down

0 comments on commit 7a92ffc

Please sign in to comment.