From 167ebf364e630e9c471507d45b520c3a5c763fd0 Mon Sep 17 00:00:00 2001 From: AndreaCossu Date: Tue, 20 Feb 2024 09:37:28 +0100 Subject: [PATCH] Fixed lamaml buffer bug --- avalanche/training/supervised/lamaml_v2.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/avalanche/training/supervised/lamaml_v2.py b/avalanche/training/supervised/lamaml_v2.py index 36905cb80..99256feb1 100644 --- a/avalanche/training/supervised/lamaml_v2.py +++ b/avalanche/training/supervised/lamaml_v2.py @@ -99,7 +99,6 @@ def __init__( buffer_mb_size=buffer_mb_size, device=device, ) - self.model.apply(init_kaiming_normal) def _before_training_exp(self, **kwargs): @@ -305,16 +304,15 @@ def __len__(self): def get_buffer_batch(self): rnd_ind = torch.randperm(len(self))[: self.buffer_mb_size] - buff_x = torch.cat( - [self.storage_policy.buffer[i][0].unsqueeze(0) for i in rnd_ind] - ).to(self.device) - buff_y = torch.LongTensor( - [self.storage_policy.buffer[i][1] for i in rnd_ind] - ).to(self.device) - buff_t = torch.LongTensor( - [self.storage_policy.buffer[i][2] for i in rnd_ind] - ).to(self.device) - + buff = self.storage_policy.buffer.subset(rnd_ind) + buff_x, buff_y, buff_t = [], [], [] + for bx, by, bt in buff: + buff_x.append(bx) + buff_y.append(by) + buff_t.append(bt) + buff_x = torch.stack(buff_x, dim=0).to(self.device) + buff_y = torch.tensor(buff_y).to(self.device).long() + buff_t = torch.tensor(buff_t).to(self.device).long() return buff_x, buff_y, buff_t