16 changes: 12 additions & 4 deletions train/loadData.py
@@ -8,7 +8,7 @@ class LoadData:
     '''
     Class to load the data
     '''
-    def __init__(self, data_dir, classes, cached_data_file, normVal=1.10):
+    def __init__(self, data_dir, classes, cached_data_file, ignored_Id, normVal=1.10):
         '''
         :param data_dir: directory where the dataset is kept
         :param classes: number of classes in the dataset
@@ -26,6 +26,7 @@ def __init__(self, data_dir, classes, cached_data_file, normVal=1.10):
         self.trainAnnotList = list()
         self.valAnnotList = list()
         self.cached_data_file = cached_data_file
+        self.ignored_Id = ignored_Id

     def compute_class_weights(self, histogram):
         '''
@@ -85,9 +86,16 @@ def readFile(self, fileName, trainStg=False):
                 self.valAnnotList.append(label_file)

             if max_val > (self.classes - 1) or min_val < 0:
-                print('Labels can take value between 0 and number of classes.')
-                print('Some problem with labels. Please check.')
-                print('Label Image ID: ' + label_file)
+                if max_val == self.ignored_Id:
+                    print('Label id: %d has been ignored' % self.ignored_Id)
+                    print('Label Image ID: ' + label_file)
+
+                else:
+                    print('Labels can take value between 0 and number of classes.')
+                    print('Some problem with labels. '
+                          'You might want to set ignored_Id = <undefined class id>. '
+                          'Please check the argument ignored_Id in main.py')
+                    print('Label Image ID: ' + label_file)
             no_files += 1

         if trainStg == True:
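A note on this hunk: the range check now distinguishes genuinely bad labels from pixels that deliberately carry the ignore id. A minimal standalone sketch of the same logic (the function name validate_labels and the numpy-based min/max are illustrative assumptions, not code from this PR):

import numpy as np

def validate_labels(label_img, classes, ignored_Id):
    # Mirrors the check in LoadData.readFile: labels must lie in
    # [0, classes - 1]; an over-range value is tolerated only when it
    # is the ignore id (e.g. 255 for unlabeled Cityscapes pixels).
    max_val = int(np.max(label_img))
    min_val = int(np.min(label_img))
    if max_val > (classes - 1) or min_val < 0:
        return max_val == ignored_Id
    return True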
21 changes: 14 additions & 7 deletions train/main.py
@@ -48,14 +48,16 @@ def val(args, val_loader, model, criterion):
         # compute the loss
         loss = criterion(output, target_var)

-        epoch_loss.append(loss.data[0])
+        #epoch_loss.append(loss.data[0])
+
+        epoch_loss.append(loss.item())

         time_taken = time.time() - start_time

         # compute the confusion matrix
         iouEvalVal.addBatch(output.max(1)[1].data, target_var.data)

-        print('[%d/%d] loss: %.3f time: %.2f' % (i, total_batches, loss.data[0], time_taken))
+        print('[%d/%d] loss: %.3f time: %.2f' % (i, total_batches, loss.item(), time_taken))

     average_epoch_loss_val = sum(epoch_loss) / len(epoch_loss)

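The .data[0] → .item() edits above (and in train below) are the standard PyTorch 0.4+ migration: the criterion now returns a zero-dimensional tensor, which can no longer be indexed, while .item() extracts the detached Python scalar. A quick illustration with arbitrary shapes:

import torch

criterion = torch.nn.CrossEntropyLoss()
loss = criterion(torch.randn(4, 20), torch.randint(0, 20, (4,)))
print(loss.item())  # Python float, safe to accumulate in epoch_loss
# loss.data[0]      # IndexError on 0-dim tensors in PyTorch >= 0.4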
@@ -102,13 +104,13 @@ def train(args, train_loader, model, criterion, optimizer, epoch):
         loss.backward()
         optimizer.step()

-        epoch_loss.append(loss.data[0])
+        epoch_loss.append(loss.item())
         time_taken = time.time() - start_time

         #compute the confusion matrix
         iouEvalTrain.addBatch(output.max(1)[1].data, target_var.data)

-        print('[%d/%d] loss: %.3f time:%.2f' % (i, total_batches, loss.data[0], time_taken))
+        print('[%d/%d] loss: %.3f time:%.2f' % (i, total_batches, loss.item(), time_taken))

     average_epoch_loss_train = sum(epoch_loss) / len(epoch_loss)
@@ -149,7 +151,7 @@ def trainValidateSegmentation(args):
     '''
     # check if processed data file exists or not
     if not os.path.isfile(args.cached_data_file):
-        dataLoad = ld.LoadData(args.data_dir, args.classes, args.cached_data_file)
+        dataLoad = ld.LoadData(args.data_dir, args.classes, args.cached_data_file, args.ignored_Id)
         data = dataLoad.processData()
         if data is None:
             print('Error while pickling data. Please check.')
@@ -159,6 +161,7 @@

     q = args.q
     p = args.p
+    ignored_Id = args.ignored_Id
     # load the model
     if not args.decoder:
         model = net.ESPNet_Encoder(args.classes, p=p, q=q)
@@ -192,7 +195,8 @@ def trainValidateSegmentation(args):
     if args.onGPU:
         weight = weight.cuda()

-    criteria = CrossEntropyLoss2d(weight) #weight
+    #criteria = CrossEntropyLoss2d(weight) #weight
+    criteria = torch.nn.CrossEntropyLoss(weight=weight, ignore_index=ignored_Id)  # the default ignore_index is -100

     if args.onGPU:
         criteria = criteria.cuda()
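Dropping the custom CrossEntropyLoss2d in favor of torch.nn.CrossEntropyLoss works because the built-in loss accepts (N, C, H, W) logits with (N, H, W) integer targets, and ignore_index (default -100) excludes those pixels from both the loss and the gradient. A small self-contained check with made-up shapes and 20 classes:

import torch

criteria = torch.nn.CrossEntropyLoss(ignore_index=255)
logits = torch.randn(2, 20, 4, 4, requires_grad=True)  # (N, C, H, W)
target = torch.randint(0, 20, (2, 4, 4))               # (N, H, W)
target[0, 0, 0] = 255           # this pixel is ignored by the loss
criteria(logits, target).backward()
print(logits.grad[0, :, 0, 0])  # all zeros at the ignored pixel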
@@ -386,7 +390,7 @@ def trainValidateSegmentation(args):
     parser.add_argument('--scaleIn', type=int, default=8, help='For ESPNet-C, scaleIn=8. For ESPNet, scaleIn=1')
     parser.add_argument('--max_epochs', type=int, default=300, help='Max. number of epochs')
     parser.add_argument('--num_workers', type=int, default=4, help='No. of parallel threads')
-    parser.add_argument('--batch_size', type=int, default=12, help='Batch size. 12 for ESPNet-C and 6 for ESPNet. '
+    parser.add_argument('--batch_size', type=int, default=8, help='Batch size. 12 for ESPNet-C and 6 for ESPNet. '
                                                                   'Change as per the GPU memory')
     parser.add_argument('--step_loss', type=int, default=100, help='Decrease learning rate after how many epochs.')
     parser.add_argument('--lr', type=float, default=5e-4, help='Initial learning rate')
@@ -402,6 +406,9 @@
                                                        'Only used when training ESPNet')
     parser.add_argument('--p', default=2, type=int, help='depth multiplier')
     parser.add_argument('--q', default=8, type=int, help='depth multiplier')
+    parser.add_argument('--ignored_Id', default=255, type=int, help='Ignored trainId for CrossEntropyLoss. '
+                                                                    'The default of 255 ignores the Cityscapes background trainId; '
+                                                                    'the background Ids are listed in cityscapesScripts/cityscapesscripts/helpers/label.py')

     trainValidateSegmentation(parser.parse_args())
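With the new flag wired through, a run that ignores the Cityscapes 'unlabeled' trainId might look like the sketch below; --data_dir and --classes are assumed from the unshown part of the parser, and the dataset path is a placeholder:

python main.py --data_dir ./cityscapes --classes 20 --batch_size 8 --ignored_Id 255

Since --ignored_Id defaults to 255, label value 255 is now always excluded from the loss unless the flag is changed; passing a value outside the label range (e.g. -100, the built-in default) restores the previous behavior.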