From 7b3ab53de710d9f65a75d35875b152a0aed62476 Mon Sep 17 00:00:00 2001 From: JonnaMat Date: Thu, 30 Mar 2023 10:32:55 +0200 Subject: [PATCH 1/2] add parser arg to set lr milestone for training --- trainer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/trainer.py b/trainer.py index 1eef06c..7fd6576 100644 --- a/trainer.py +++ b/trainer.py @@ -55,6 +55,8 @@ parser.add_argument('--save-every', dest='save_every', help='Saves checkpoints at every specified number of epochs', type=int, default=10) +parser.add_argument('--lr-milestones', default=[100, 150], nargs='+', + help='list of epoch indices for multi step learning rate scheduler', type=int) best_prec1 = 0 @@ -116,9 +118,9 @@ def main(): optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay) - + print(args.lr_milestones) lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, - milestones=[100, 150], last_epoch=args.start_epoch - 1) + milestones=args.lr_milestones, last_epoch=args.start_epoch - 1) if args.arch in ['resnet1202', 'resnet110']: # for resnet1202 original paper uses lr=0.01 for first 400 minibatches for warm-up From 78a87f1c5d4f15c13598f7ebb3f72929c623303b Mon Sep 17 00:00:00 2001 From: JonnaMat Date: Thu, 30 Mar 2023 10:37:45 +0200 Subject: [PATCH 2/2] remove print statement from testing --- trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trainer.py b/trainer.py index 7fd6576..99ba853 100644 --- a/trainer.py +++ b/trainer.py @@ -118,7 +118,7 @@ def main(): optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay) - print(args.lr_milestones) + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_milestones, last_epoch=args.start_epoch - 1)