From 7a92ffcfc28fcf4d12179b3d7dec4f2e4d6b8ed4 Mon Sep 17 00:00:00 2001
From: AlbinSou
Date: Sat, 7 Oct 2023 16:22:58 +0200
Subject: [PATCH] added last subexp checking to at_task_boundaries util

---
 avalanche/training/utils.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/avalanche/training/utils.py b/avalanche/training/utils.py
index b51a60f7f..50c121ab3 100644
--- a/avalanche/training/utils.py
+++ b/avalanche/training/utils.py
@@ -15,15 +15,15 @@
 """
 
 from collections import defaultdict
-from typing import Dict, NamedTuple, List, Optional, Tuple, Callable, Union
+from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
-from torch.nn import Module, Linear
-from torch.utils.data import Dataset, DataLoader
+from torch.nn import Linear, Module
+from torch.utils.data import DataLoader, Dataset
 
-from avalanche.models.batch_renorm import BatchRenorm2D
 from avalanche.benchmarks import OnlineCLExperience
+from avalanche.models.batch_renorm import BatchRenorm2D
 
 
 def at_task_boundary(training_experience) -> bool:
@@ -45,7 +45,10 @@ def at_task_boundary(training_experience) -> bool:
 
     if isinstance(training_experience, OnlineCLExperience):
         if training_experience.access_task_boundaries:
-            if training_experience.is_first_subexp:
+            if (
+                training_experience.is_first_subexp
+                or training_experience.is_last_subexp
+            ):
                 return True
     else:
         return True
@@ -222,7 +225,7 @@ def replace_bn_with_brn(
 ):
     for attr_str in dir(m):
         target_attr = getattr(m, attr_str)
-        if type(target_attr) == torch.nn.BatchNorm2d:
+        if isinstance(target_attr, torch.nn.BatchNorm2d):
             # print('replaced: ', name, attr_str)
             setattr(
                 m,
@@ -253,7 +256,7 @@ def change_brn_pars(
 ):
     for attr_str in dir(m):
         target_attr = getattr(m, attr_str)
-        if type(target_attr) == BatchRenorm2D:
+        if isinstance(target_attr, BatchRenorm2D):
             target_attr.momentum = torch.tensor((momentum), requires_grad=False)
             target_attr.r_max = torch.tensor(r_max, requires_grad=False)
             target_attr.d_max = torch.tensor(d_max, requires_grad=False)