Merge pull request #1601 from AndreaCossu/master
LaMAMLv2 buffer now uses subset instead of indexing
AntonioCarta authored Feb 20, 2024
2 parents 9bc65ef + 743bebf commit 2d289fa
Showing 1 changed file with 9 additions and 11 deletions.
avalanche/training/supervised/lamaml_v2.py
@@ -99,7 +99,6 @@ def __init__(
             buffer_mb_size=buffer_mb_size,
             device=device,
         )
-
         self.model.apply(init_kaiming_normal)
 
     def _before_training_exp(self, **kwargs):
@@ -305,16 +304,15 @@ def __len__(self):
 
     def get_buffer_batch(self):
         rnd_ind = torch.randperm(len(self))[: self.buffer_mb_size]
-        buff_x = torch.cat(
-            [self.storage_policy.buffer[i][0].unsqueeze(0) for i in rnd_ind]
-        ).to(self.device)
-        buff_y = torch.LongTensor(
-            [self.storage_policy.buffer[i][1] for i in rnd_ind]
-        ).to(self.device)
-        buff_t = torch.LongTensor(
-            [self.storage_policy.buffer[i][2] for i in rnd_ind]
-        ).to(self.device)
-
+        buff = self.storage_policy.buffer.subset(rnd_ind)
+        buff_x, buff_y, buff_t = [], [], []
+        for bx, by, bt in buff:
+            buff_x.append(bx)
+            buff_y.append(by)
+            buff_t.append(bt)
+        buff_x = torch.stack(buff_x, dim=0).to(self.device)
+        buff_y = torch.tensor(buff_y).to(self.device).long()
+        buff_t = torch.tensor(buff_t).to(self.device).long()
         return buff_x, buff_y, buff_t


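The change replaces per-index lookups into the storage policy's buffer with a single subset call followed by one pass over the returned items. In the old code, every `self.storage_policy.buffer[i][k]` access re-fetched the same sample up to three times (once per field); the subset-then-iterate form fetches each sample once. Below is a minimal, self-contained sketch of the new pattern; `ToyBuffer` and its `subset` method are hypothetical stand-ins for Avalanche's actual buffer object, which this commit assumes exposes a comparable `subset(indices)`.

```python
import torch


class ToyBuffer:
    """Hypothetical stand-in for the storage policy's buffer.

    Indexing yields a single (x, y, task) triple; subset() returns a new
    buffer restricted to the given indices, mirroring the pattern the
    commit switches to.
    """

    def __init__(self, xs, ys, ts):
        self.xs, self.ys, self.ts = xs, ys, ts

    def __len__(self):
        return len(self.xs)

    def __getitem__(self, i):
        return self.xs[i], self.ys[i], self.ts[i]

    def subset(self, indices):
        return ToyBuffer(
            [self.xs[i] for i in indices],
            [self.ys[i] for i in indices],
            [self.ts[i] for i in indices],
        )


buffer = ToyBuffer(
    xs=[torch.randn(3) for _ in range(10)],  # 10 samples of shape (3,)
    ys=list(range(10)),                      # class labels
    ts=[0] * 10,                             # task labels
)
buffer_mb_size = 4
rnd_ind = torch.randperm(len(buffer))[:buffer_mb_size]

# New pattern: one subset() call, then a single pass over the result.
buff = buffer.subset(rnd_ind.tolist())
buff_x, buff_y, buff_t = [], [], []
for bx, by, bt in buff:
    buff_x.append(bx)
    buff_y.append(by)
    buff_t.append(bt)
buff_x = torch.stack(buff_x, dim=0)  # shape: (4, 3)
buff_y = torch.tensor(buff_y).long()
buff_t = torch.tensor(buff_t).long()
print(buff_x.shape, buff_y, buff_t)
```

As in the rewritten `get_buffer_batch`, the lists are converted to tensors once at the end, so `torch.stack` and the `.to(self.device)` transfers each run a single time per batch rather than per sample.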
