Skip to content

Commit 2d289fa

Browse files
authored
Merge pull request #1601 from AndreaCossu/master
LaMAMLv2 buffer now uses subset instead of indexing
2 parents 9bc65ef + 743bebf commit 2d289fa

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

avalanche/training/supervised/lamaml_v2.py

+9-11
Original file line number | Diff line number | Diff line change
@@ -99,7 +99,6 @@ def __init__(
9999
buffer_mb_size=buffer_mb_size,
100100
device=device,
101101
)
102-
103102
self.model.apply(init_kaiming_normal)
104103

105104
def _before_training_exp(self, **kwargs):
@@ -305,16 +304,15 @@ def __len__(self):
305304

306305
def get_buffer_batch(self):
307306
rnd_ind = torch.randperm(len(self))[: self.buffer_mb_size]
308-
buff_x = torch.cat(
309-
[self.storage_policy.buffer[i][0].unsqueeze(0) for i in rnd_ind]
310-
).to(self.device)
311-
buff_y = torch.LongTensor(
312-
[self.storage_policy.buffer[i][1] for i in rnd_ind]
313-
).to(self.device)
314-
buff_t = torch.LongTensor(
315-
[self.storage_policy.buffer[i][2] for i in rnd_ind]
316-
).to(self.device)
317-
307+
buff = self.storage_policy.buffer.subset(rnd_ind)
308+
buff_x, buff_y, buff_t = [], [], []
309+
for bx, by, bt in buff:
310+
buff_x.append(bx)
311+
buff_y.append(by)
312+
buff_t.append(bt)
313+
buff_x = torch.stack(buff_x, dim=0).to(self.device)
314+
buff_y = torch.tensor(buff_y).to(self.device).long()
315+
buff_t = torch.tensor(buff_t).to(self.device).long()
318316
return buff_x, buff_y, buff_t
319317

320318

0 commit comments

Comments (0)