diff --git a/avalanche/training/supervised/lamaml_v2.py b/avalanche/training/supervised/lamaml_v2.py
index 36905cb80..99256feb1 100644
--- a/avalanche/training/supervised/lamaml_v2.py
+++ b/avalanche/training/supervised/lamaml_v2.py
@@ -99,7 +99,6 @@ def __init__(
             buffer_mb_size=buffer_mb_size,
             device=device,
         )
-        self.model.apply(init_kaiming_normal)
 
     def _before_training_exp(self, **kwargs):
@@ -305,16 +304,15 @@ def __len__(self):
 
     def get_buffer_batch(self):
         rnd_ind = torch.randperm(len(self))[: self.buffer_mb_size]
-        buff_x = torch.cat(
-            [self.storage_policy.buffer[i][0].unsqueeze(0) for i in rnd_ind]
-        ).to(self.device)
-        buff_y = torch.LongTensor(
-            [self.storage_policy.buffer[i][1] for i in rnd_ind]
-        ).to(self.device)
-        buff_t = torch.LongTensor(
-            [self.storage_policy.buffer[i][2] for i in rnd_ind]
-        ).to(self.device)
-
+        buff = self.storage_policy.buffer.subset(rnd_ind)
+        buff_x, buff_y, buff_t = [], [], []
+        for bx, by, bt in buff:
+            buff_x.append(bx)
+            buff_y.append(by)
+            buff_t.append(bt)
+        buff_x = torch.stack(buff_x, dim=0).to(self.device)
+        buff_y = torch.tensor(buff_y).to(self.device).long()
+        buff_t = torch.tensor(buff_t).to(self.device).long()
         return buff_x, buff_t, buff_t
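The second hunk replaces per-item indexing of the replay buffer (`buffer[i][0]`, `buffer[i][1]`, `buffer[i][2]`) with a single `subset(rnd_ind)` call that is then iterated once to collect inputs, labels, and task ids. Below is a minimal, self-contained sketch of that sampling pattern, not the actual Avalanche code: `_ToyBuffer` and the free function `get_buffer_batch` are hypothetical stand-ins assuming the buffer is an indexable collection of `(x, y, task_id)` tuples exposing a `subset(indices)` view, as the diff suggests.

```python
import torch


class _ToyBuffer:
    """Hypothetical stand-in for storage_policy.buffer: an indexable
    collection of (x, y, task_id) tuples with a subset(indices) view."""

    def __init__(self, xs, ys, ts):
        self._items = list(zip(xs, ys, ts))

    def __len__(self):
        return len(self._items)

    def subset(self, indices):
        # Return the selected (x, y, task_id) tuples in index order.
        return [self._items[int(i)] for i in indices]


def get_buffer_batch(buffer, buffer_mb_size, device="cpu"):
    # Sample up to buffer_mb_size random indices without replacement.
    rnd_ind = torch.randperm(len(buffer))[:buffer_mb_size]
    # Take one subset view instead of indexing the buffer item by item.
    buff_x, buff_y, buff_t = [], [], []
    for bx, by, bt in buffer.subset(rnd_ind):
        buff_x.append(bx)
        buff_y.append(by)
        buff_t.append(bt)
    # Stack inputs into a batch; labels and task ids become long tensors.
    buff_x = torch.stack(buff_x, dim=0).to(device)
    buff_y = torch.tensor(buff_y).to(device).long()
    buff_t = torch.tensor(buff_t).to(device).long()
    return buff_x, buff_y, buff_t


if __name__ == "__main__":
    xs = [torch.randn(3, 32, 32) for _ in range(10)]
    ys = list(range(10))
    ts = [0] * 10
    x, y, t = get_buffer_batch(_ToyBuffer(xs, ys, ts), buffer_mb_size=4)
    print(x.shape, y.shape, t.shape)  # torch.Size([4, 3, 32, 32]) ...
```

Iterating a single `subset(...)` view keeps one pass over the sampled items and avoids the repeated `buffer[i]` lookups of the old version; `torch.stack` replaces the `unsqueeze(0)` + `torch.cat` idiom with the same result.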