From f4f5560c2864e5a1c9d5e60afdb02f84d6fb35aa Mon Sep 17 00:00:00 2001 From: zhuotaotian Date: Mon, 3 Aug 2020 23:16:09 +0800 Subject: [PATCH] first commit --- README.md | 82 ++++ config/coco/coco_split0_resnet101.yaml | 62 +++ config/coco/coco_split0_vgg.yaml | 62 +++ config/pascal/pascal_split0_resnet50.yaml | 61 +++ model/PFENet.py | 309 +++++++++++++++ model/resnet.py | 233 +++++++++++ model/vgg.py | 246 ++++++++++++ test.py | 263 +++++++++++++ test.sh | 15 + train.py | 452 ++++++++++++++++++++++ train.sh | 15 + util/config.py | 166 ++++++++ util/dataset.py | 245 ++++++++++++ util/transform.py | 376 ++++++++++++++++++ util/util.py | 146 +++++++ 15 files changed, 2733 insertions(+) create mode 100644 README.md create mode 100644 config/coco/coco_split0_resnet101.yaml create mode 100644 config/coco/coco_split0_vgg.yaml create mode 100644 config/pascal/pascal_split0_resnet50.yaml create mode 100755 model/PFENet.py create mode 100755 model/resnet.py create mode 100755 model/vgg.py create mode 100755 test.py create mode 100755 test.sh create mode 100755 train.py create mode 100755 train.sh create mode 100755 util/config.py create mode 100755 util/dataset.py create mode 100755 util/transform.py create mode 100755 util/util.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..51a8193 --- /dev/null +++ b/README.md @@ -0,0 +1,82 @@ +## PFENet +This is the implementation of our paper **PFENet: Prior Guided Feature Enrichment Network for Few-shot Segmentation** that has been accepted to IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI). + +## Get Started + +### Environment ++ torch==1.4.0 (torch version >= 1.0.1.post2 should be okay to run this repo) ++ numpy==1.18.4 ++ tensorboardX==1.8 ++ cv2==4.2.0 + + +### Datasets and Data Preparation + ++ PASCAL-5i is based on the PASCAL VOC 2012 (http://host.robots.ox.ac.uk/pascal/VOC/voc2012/) and SBD (http://home.bharathh.info/pubs/codes/SBD/download.html) where the val images should be excluded. + ++ COCO: https://cocodataset.org/#download. + +This code reads data from .txt files where each line contains the paths of an image and the corresponding label, respectively. Image and label paths are separated by a space. An example is as follows: + + image_path_1 label_path_1 + + image_path_2 label_path_2 + + image_path_3 label_path_3 + + ... + + image_path_n label_path_n + +Then update the train/val/test list paths in the config files. + +### Run Demo / Test with Pretrained Models ++ Please download the pretrained models. ++ We provide **8 pre-trained models**: 4 ResNet-50 based [**models**](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155122171_link_cuhk_edu_hk/EW20i_eiTINDgJDqUqikNR4Bo-7kVFkLBkxGZ2_uorOJcw?e=4%3aSIRlwD&at=9) for PASCAL-5i and 4 VGG-16 based [**models**](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155122171_link_cuhk_edu_hk/EYS498D4TOZMtIb3WbQDGSQBsqxJHLSiMEAa49Iym0NO0A?e=4%3apRTPnj&at=9) for COCO. ++ Update the config file by specifying the target **split** and **path** (`weight`) for loading the checkpoint. ++ Execute `mkdir initmodel` at the root directory. ++ Download the ImageNet pretrained [**backbones**](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155122171_link_cuhk_edu_hk/EQEY0JxITwVHisdVzusEqNUBNsf1CT8MsALdahUhaHrhlw?e=4%3a2o3XTL&at=9) and put them into the `initmodel` directory.
++ Then execute the command: + + `sh test.sh {*dataset*} {*model_config*}` + +Example: Test PFENet with ResNet-50 on split 0 of PASCAL-5i: + + sh test.sh pascal split0_resnet50 + + +### Train + +Execute this command at the root directory: + + sh train.sh {*dataset*} {*model_config*} + + +## Related Repositories + +This project is built upon a very early version of **SemSeg**: https://github.com/hszhao/semseg. + +Other projects in few-shot segmentation: ++ OSLSM: https://github.com/lzzcd001/OSLSM ++ CANet: https://github.com/icoz69/CaNet ++ PANet: https://github.com/kaixin96/PANet ++ FSS-1000: https://github.com/HKUSTCV/FSS-1000 ++ AMP: https://github.com/MSiam/AdaptiveMaskedProxies ++ On the Texture Bias for FS Seg: https://github.com/rezazad68/fewshot-segmentation ++ SG-One: https://github.com/xiaomengyc/SG-One ++ FS Seg Propagation with Guided Networks: https://github.com/shelhamer/revolver + + +Many thanks for their great work! + +## Citation + +If you find this project useful, please consider citing: +``` +@article{tian2020pfenet, + title={Prior Guided Feature Enrichment Network for Few-Shot Segmentation}, + author={Tian, Zhuotao and Zhao, Hengshuang and Shu, Michelle and Yang, Zhicheng and Li, Ruiyu and Jia, Jiaya}, + journal={TPAMI}, + year={2020} +} +``` diff --git a/config/coco/coco_split0_resnet101.yaml b/config/coco/coco_split0_resnet101.yaml new file mode 100644 index 0000000..3d04acb --- /dev/null +++ b/config/coco/coco_split0_resnet101.yaml @@ -0,0 +1,62 @@ +DATA: + data_root: + train_list: + val_list: + classes: 2 + +TRAIN: + layers: 101 # 50 or 101 + sync_bn: False + train_h: 641 + train_w: 641 + val_size: 641 + scale_min: 0.8 # minimum random scale + scale_max: 1.25 # maximum random scale + rotate_min: -10 # minimum random rotate + rotate_max: 10 # maximum random rotate + zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] + ignore_label: 255 + padding_label: 255 + aux_weight: 1.0 + train_gpu: [0,1,2,3] # If only one gpu is used, batch size can be set to 8 and base_lr should be 0.005. + workers: 16 # data loader workers + batch_size: 32 # batch size for training.
+ batch_size_val: 1 # this version of the code only supports val batch size = 1 + base_lr: 0.02 + epochs: 50 + start_epoch: 0 + power: 0.9 # 0 means no decay + momentum: 0.9 + weight_decay: 0.0001 + manual_seed: 321 + print_freq: 5 + save_freq: 20 + save_path: exp/coco/split0_resnet101/model + weight: + resume: # path to latest checkpoint (default: none) + evaluate: True + split: 0 + shot: 1 + vgg: False # whether to use vgg as the backbone + ppm_scales: [1.0, 0.5, 0.25, 0.125] + fix_random_seed_val: True + warmup: False + use_coco: True + use_split_coco: True + resized_val: True + ori_resize: True # use original label for evaluation + +## deprecated multi-processing training +Distributed: + dist_url: tcp://127.0.0.1:6789 + dist_backend: 'nccl' + multiprocessing_distributed: False + world_size: 1 + rank: 0 + use_apex: False + opt_level: 'O0' + keep_batchnorm_fp32: + loss_scale: + + + diff --git a/config/coco/coco_split0_vgg.yaml b/config/coco/coco_split0_vgg.yaml new file mode 100644 index 0000000..d9087d9 --- /dev/null +++ b/config/coco/coco_split0_vgg.yaml @@ -0,0 +1,62 @@ +DATA: + data_root: + train_list: + val_list: + classes: 2 + +TRAIN: + layers: 101 # 50 or 101 + sync_bn: False + train_h: 641 + train_w: 641 + val_size: 641 + scale_min: 0.8 # minimum random scale + scale_max: 1.25 # maximum random scale + rotate_min: -10 # minimum random rotate + rotate_max: 10 # maximum random rotate + zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] + ignore_label: 255 + padding_label: 255 + aux_weight: 1.0 + train_gpu: [0,1,2,3] # If only one gpu is used, batch size can be set to 8 and base_lr should be 0.005. + workers: 16 # data loader workers + batch_size: 32 # batch size for training. + batch_size_val: 1 # this version of the code only supports val batch size = 1 + base_lr: 0.02 + epochs: 50 + start_epoch: 0 + power: 0.9 # 0 means no decay + momentum: 0.9 + weight_decay: 0.0001 + manual_seed: 321 + print_freq: 5 + save_freq: 20 + save_path: exp/coco/split0_vgg/model + weight: + resume: # path to latest checkpoint (default: none) + evaluate: True + split: 0 + shot: 1 + vgg: True # whether to use vgg as the backbone + ppm_scales: [1.0, 0.5, 0.25, 0.125] + fix_random_seed_val: True + warmup: False + use_coco: True + use_split_coco: True + resized_val: True + ori_resize: True # use original label for evaluation + +## deprecated multi-processing training +Distributed: + dist_url: tcp://127.0.0.1:6789 + dist_backend: 'nccl' + multiprocessing_distributed: False + world_size: 1 + rank: 0 + use_apex: False + opt_level: 'O0' + keep_batchnorm_fp32: + loss_scale: + + + diff --git a/config/pascal/pascal_split0_resnet50.yaml b/config/pascal/pascal_split0_resnet50.yaml new file mode 100644 index 0000000..6103fc7 --- /dev/null +++ b/config/pascal/pascal_split0_resnet50.yaml @@ -0,0 +1,61 @@ +DATA: + data_root: + train_list: + val_list: + classes: 2 + + +TRAIN: + layers: 50 + sync_bn: False + train_h: 473 + train_w: 473 + val_size: 473 + scale_min: 0.9 # minimum random scale + scale_max: 1.1 # maximum random scale + rotate_min: -10 # minimum random rotate + rotate_max: 10 # maximum random rotate + zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] + ignore_label: 255 + padding_label: 255 + aux_weight: 1.0 + train_gpu: [0] + workers: 8 # data loader workers + batch_size: 4 # batch size for training + batch_size_val: 1 + base_lr: 0.0025 + epochs: 200 + start_epoch: 0 + power: 0.9 # 0 means no decay + momentum: 0.9 + weight_decay: 0.0001 + manual_seed: 321 +
print_freq: 5 + save_freq: 20 + save_path: exp/pascal/split0_resnet50/model + weight: # load weight for fine-tuning or testing + resume: # path to latest checkpoint (default: none) + evaluate: True + split: 0 + shot: 1 + vgg: False + ppm_scales: [60, 30, 15, 8] + fix_random_seed_val: True + warmup: False + use_coco: False + use_split_coco: False + resized_val: True + ori_resize: True # use original label for evaluation + +## deprecated multi-processing training +Distributed: + dist_url: tcp://127.0.0.1:6789 + dist_backend: 'nccl' + multiprocessing_distributed: False + world_size: 1 + rank: 0 + use_apex: False + opt_level: 'O0' + keep_batchnorm_fp32: + loss_scale: + diff --git a/model/PFENet.py b/model/PFENet.py new file mode 100755 index 0000000..f1a7253 --- /dev/null +++ b/model/PFENet.py @@ -0,0 +1,309 @@ +import torch +from torch import nn +import torch.nn.functional as F +import numpy as np +import random +import time +import cv2 + +import model.resnet as models +import model.vgg as vgg_models + + +def Weighted_GAP(supp_feat, mask): + supp_feat = supp_feat * mask + feat_h, feat_w = supp_feat.shape[-2:][0], supp_feat.shape[-2:][1] + area = F.avg_pool2d(mask, (supp_feat.size()[2], supp_feat.size()[3])) * feat_h * feat_w + 0.0005 + supp_feat = F.avg_pool2d(input=supp_feat, kernel_size=supp_feat.shape[-2:]) * feat_h * feat_w / area + return supp_feat + +def get_vgg16_layer(model): + layer0_idx = range(0,7) + layer1_idx = range(7,14) + layer2_idx = range(14,24) + layer3_idx = range(24,34) + layer4_idx = range(34,43) + layers_0 = [] + layers_1 = [] + layers_2 = [] + layers_3 = [] + layers_4 = [] + for idx in layer0_idx: + layers_0 += [model.features[idx]] + for idx in layer1_idx: + layers_1 += [model.features[idx]] + for idx in layer2_idx: + layers_2 += [model.features[idx]] + for idx in layer3_idx: + layers_3 += [model.features[idx]] + for idx in layer4_idx: + layers_4 += [model.features[idx]] + layer0 = nn.Sequential(*layers_0) + layer1 = nn.Sequential(*layers_1) + layer2 = nn.Sequential(*layers_2) + layer3 = nn.Sequential(*layers_3) + layer4 = nn.Sequential(*layers_4) + return layer0,layer1,layer2,layer3,layer4 + +class PFENet(nn.Module): + def __init__(self, layers=50, classes=2, zoom_factor=8, \ + criterion=nn.CrossEntropyLoss(ignore_index=255), BatchNorm=nn.BatchNorm2d, \ + pretrained=True, sync_bn=True, shot=1, ppm_scales=[60, 30, 15, 8], vgg=False): + super(PFENet, self).__init__() + assert layers in [50, 101, 152] + print(ppm_scales) + assert classes > 1 + from torch.nn import BatchNorm2d as BatchNorm + self.zoom_factor = zoom_factor + self.criterion = criterion + self.shot = shot + self.ppm_scales = ppm_scales + self.vgg = vgg + + models.BatchNorm = BatchNorm + + if self.vgg: + print('INFO: Using VGG_16 bn') + vgg_models.BatchNorm = BatchNorm + vgg16 = vgg_models.vgg16_bn(pretrained=pretrained) + print(vgg16) + self.layer0, self.layer1, self.layer2, \ + self.layer3, self.layer4 = get_vgg16_layer(vgg16) + + else: + print('INFO: Using ResNet {}'.format(layers)) + if layers == 50: + resnet = models.resnet50(pretrained=pretrained) + elif layers == 101: + resnet = models.resnet101(pretrained=pretrained) + else: + resnet = models.resnet152(pretrained=pretrained) + self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu1, resnet.conv2, resnet.bn2, resnet.relu2, resnet.conv3, resnet.bn3, resnet.relu3, resnet.maxpool) + self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4 + + for n, m in self.layer3.named_modules(): + if 
'conv2' in n: + m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1) + elif 'downsample.0' in n: + m.stride = (1, 1) + for n, m in self.layer4.named_modules(): + if 'conv2' in n: + m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1) + elif 'downsample.0' in n: + m.stride = (1, 1) + + reduce_dim = 256 + if self.vgg: + fea_dim = 512 + 256 + else: + fea_dim = 1024 + 512 + + self.cls = nn.Sequential( + nn.Conv2d(reduce_dim, reduce_dim, kernel_size=3, padding=1, bias=False), + nn.ReLU(inplace=True), + nn.Dropout2d(p=0.1), + nn.Conv2d(reduce_dim, classes, kernel_size=1) + ) + + self.down_query = nn.Sequential( + nn.Conv2d(fea_dim, reduce_dim, kernel_size=1, padding=0, bias=False), + nn.ReLU(inplace=True), + nn.Dropout2d(p=0.5) + ) + self.down_supp = nn.Sequential( + nn.Conv2d(fea_dim, reduce_dim, kernel_size=1, padding=0, bias=False), + nn.ReLU(inplace=True), + nn.Dropout2d(p=0.5) + ) + + self.pyramid_bins = ppm_scales + self.avgpool_list = [] + for bin in self.pyramid_bins: + if bin > 1: + self.avgpool_list.append( + nn.AdaptiveAvgPool2d(bin) + ) + + + factor = 1 + mask_add_num = 1 + self.init_merge = [] + self.beta_conv = [] + self.inner_cls = [] + for bin in self.pyramid_bins: + self.init_merge.append(nn.Sequential( + nn.Conv2d(reduce_dim*2 + mask_add_num, reduce_dim, kernel_size=1, padding=0, bias=False), + nn.ReLU(inplace=True), + )) + self.beta_conv.append(nn.Sequential( + nn.Conv2d(reduce_dim, reduce_dim, kernel_size=3, padding=1, bias=False), + nn.ReLU(inplace=True), + nn.Conv2d(reduce_dim, reduce_dim, kernel_size=3, padding=1, bias=False), + nn.ReLU(inplace=True) + )) + self.inner_cls.append(nn.Sequential( + nn.Conv2d(reduce_dim, reduce_dim, kernel_size=3, padding=1, bias=False), + nn.ReLU(inplace=True), + nn.Dropout2d(p=0.1), + nn.Conv2d(reduce_dim, classes, kernel_size=1) + )) + self.init_merge = nn.ModuleList(self.init_merge) + self.beta_conv = nn.ModuleList(self.beta_conv) + self.inner_cls = nn.ModuleList(self.inner_cls) + + + self.res1 = nn.Sequential( + nn.Conv2d(reduce_dim*len(self.pyramid_bins), reduce_dim, kernel_size=1, padding=0, bias=False), + nn.ReLU(inplace=True), + ) + self.res2 = nn.Sequential( + nn.Conv2d(reduce_dim, reduce_dim, kernel_size=3, padding=1, bias=False), + nn.ReLU(inplace=True), + nn.Conv2d(reduce_dim, reduce_dim, kernel_size=3, padding=1, bias=False), + nn.ReLU(inplace=True), + ) + + self.GAP = nn.AdaptiveAvgPool2d(1) + + self.alpha_conv = [] + for idx in range(len(self.pyramid_bins)-1): + self.alpha_conv.append(nn.Sequential( + nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0, bias=False), + nn.ReLU() + )) + self.alpha_conv = nn.ModuleList(self.alpha_conv) + + + + def forward(self, x, s_x=torch.FloatTensor(1,1,3,473,473).cuda(), s_y=torch.FloatTensor(1,1,473,473).cuda(), y=None): + x_size = x.size() + assert (x_size[2]-1) % 8 == 0 and (x_size[3]-1) % 8 == 0 + h = int((x_size[2] - 1) / 8 * self.zoom_factor + 1) + w = int((x_size[3] - 1) / 8 * self.zoom_factor + 1) + + # Query Feature + with torch.no_grad(): + query_feat_0 = self.layer0(x) + query_feat_1 = self.layer1(query_feat_0) + query_feat_2 = self.layer2(query_feat_1) + query_feat_3 = self.layer3(query_feat_2) + query_feat_4 = self.layer4(query_feat_3) + if self.vgg: + query_feat_2 = F.interpolate(query_feat_2, size=(query_feat_3.size(2),query_feat_3.size(3)), mode='bilinear', align_corners=True) + + query_feat = torch.cat([query_feat_3, query_feat_2], 1) + query_feat = self.down_query(query_feat) + + # Support Feature + supp_feat_list = [] + final_supp_list = [] + mask_list = [] + for 
i in range(self.shot): + mask = (s_y[:,i,:,:] == 1).float().unsqueeze(1) + mask_list.append(mask) + with torch.no_grad(): + supp_feat_0 = self.layer0(s_x[:,i,:,:,:]) + supp_feat_1 = self.layer1(supp_feat_0) + supp_feat_2 = self.layer2(supp_feat_1) + supp_feat_3 = self.layer3(supp_feat_2) + mask = F.interpolate(mask, size=(supp_feat_3.size(2), supp_feat_3.size(3)), mode='bilinear', align_corners=True) + supp_feat_4 = self.layer4(supp_feat_3*mask) + final_supp_list.append(supp_feat_4) + if self.vgg: + supp_feat_2 = F.interpolate(supp_feat_2, size=(supp_feat_3.size(2),supp_feat_3.size(3)), mode='bilinear', align_corners=True) + + supp_feat = torch.cat([supp_feat_3, supp_feat_2], 1) + supp_feat = self.down_supp(supp_feat) + supp_feat = Weighted_GAP(supp_feat, mask) + supp_feat_list.append(supp_feat) + + + corr_query_mask_list = [] + cosine_eps = 1e-7 + for i, tmp_supp_feat in enumerate(final_supp_list): + resize_size = tmp_supp_feat.size(2) + tmp_mask = F.interpolate(mask_list[i], size=(resize_size, resize_size), mode='bilinear', align_corners=True) + + tmp_supp_feat_4 = tmp_supp_feat * tmp_mask + q = query_feat_4 + s = tmp_supp_feat_4 + bsize, ch_sz, sp_sz, _ = q.size()[:] + + tmp_query = q + tmp_query = tmp_query.contiguous().view(bsize, ch_sz, -1) + tmp_query_norm = torch.norm(tmp_query, 2, 1, True) + + tmp_supp = s + tmp_supp = tmp_supp.contiguous().view(bsize, ch_sz, -1) + tmp_supp = tmp_supp.contiguous().permute(0, 2, 1) + tmp_supp_norm = torch.norm(tmp_supp, 2, 2, True) + + similarity = torch.bmm(tmp_supp, tmp_query)/(torch.bmm(tmp_supp_norm, tmp_query_norm) + cosine_eps) + similarity = similarity.max(1)[0].view(bsize, sp_sz*sp_sz) + similarity = (similarity - similarity.min(1)[0].unsqueeze(1))/(similarity.max(1)[0].unsqueeze(1) - similarity.min(1)[0].unsqueeze(1) + cosine_eps) + corr_query = similarity.view(bsize, 1, sp_sz, sp_sz) + corr_query = F.interpolate(corr_query, size=(query_feat_3.size()[2], query_feat_3.size()[3]), mode='bilinear', align_corners=True) + corr_query_mask_list.append(corr_query) + corr_query_mask = torch.cat(corr_query_mask_list, 1).mean(1).unsqueeze(1) + corr_query_mask = F.interpolate(corr_query_mask, size=(query_feat.size(2), query_feat.size(3)), mode='bilinear', align_corners=True) + + if self.shot > 1: + supp_feat = supp_feat_list[0] + for i in range(1, len(supp_feat_list)): + supp_feat += supp_feat_list[i] + supp_feat /= len(supp_feat_list) + + out_list = [] + pyramid_feat_list = [] + + for idx, tmp_bin in enumerate(self.pyramid_bins): + if tmp_bin <= 1.0: + bin = int(query_feat.shape[2] * tmp_bin) + query_feat_bin = nn.AdaptiveAvgPool2d(bin)(query_feat) + else: + bin = tmp_bin + query_feat_bin = self.avgpool_list[idx](query_feat) + supp_feat_bin = supp_feat.expand(-1, -1, bin, bin) + corr_mask_bin = F.interpolate(corr_query_mask, size=(bin, bin), mode='bilinear', align_corners=True) + merge_feat_bin = torch.cat([query_feat_bin, supp_feat_bin, corr_mask_bin], 1) + merge_feat_bin = self.init_merge[idx](merge_feat_bin) + + if idx >= 1: + pre_feat_bin = pyramid_feat_list[idx-1].clone() + pre_feat_bin = F.interpolate(pre_feat_bin, size=(bin, bin), mode='bilinear', align_corners=True) + rec_feat_bin = torch.cat([merge_feat_bin, pre_feat_bin], 1) + merge_feat_bin = self.alpha_conv[idx-1](rec_feat_bin) + merge_feat_bin + + merge_feat_bin = self.beta_conv[idx](merge_feat_bin) + merge_feat_bin + inner_out_bin = self.inner_cls[idx](merge_feat_bin) + merge_feat_bin = F.interpolate(merge_feat_bin, size=(query_feat.size(2), query_feat.size(3)), mode='bilinear', 
align_corners=True) + pyramid_feat_list.append(merge_feat_bin) + out_list.append(inner_out_bin) + + query_feat = torch.cat(pyramid_feat_list, 1) + query_feat = self.res1(query_feat) + query_feat = self.res2(query_feat) + query_feat + out = self.cls(query_feat) + + + # Output Part + if self.zoom_factor != 1: + out = F.interpolate(out, size=(h, w), mode='bilinear', align_corners=True) + + if self.training: + main_loss = self.criterion(out, y.long()) + aux_loss = torch.zeros_like(main_loss).cuda() + + for idx_k in range(len(out_list)): + inner_out = out_list[idx_k] + inner_out = F.interpolate(inner_out, size=(h, w), mode='bilinear', align_corners=True) + aux_loss = aux_loss + self.criterion(inner_out, y.long()) + aux_loss = aux_loss / len(out_list) + return out.max(1)[1], main_loss, aux_loss + else: + return out + + + + + diff --git a/model/resnet.py b/model/resnet.py new file mode 100755 index 0000000..6f165e3 --- /dev/null +++ b/model/resnet.py @@ -0,0 +1,233 @@ +import torch +import torch.nn as nn +import math +import torch.utils.model_zoo as model_zoo + +BatchNorm = nn.BatchNorm2d + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = BatchNorm(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = BatchNorm(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = BatchNorm(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = BatchNorm(planes) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = BatchNorm(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, deep_base=True): + super(ResNet, self).__init__() + 
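+        # Note (added comment): deep_base=True replaces the usual single 7x7 stem conv with three stacked 3x3 convs (3 -> 64 -> 64 -> 128); the ./initmodel/resnet*_v2.pth weights loaded by resnet50/101/152 below are saved for this stem layout.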
self.deep_base = deep_base + if not self.deep_base: + self.inplanes = 64 + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = BatchNorm(64) + self.relu = nn.ReLU(inplace=True) + else: + self.inplanes = 128 + self.conv1 = conv3x3(3, 64, stride=2) + self.bn1 = BatchNorm(64) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = conv3x3(64, 64) + self.bn2 = BatchNorm(64) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = conv3x3(64, 128) + self.bn3 = BatchNorm(128) + self.relu3 = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, BatchNorm): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + BatchNorm(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.relu1(self.bn1(self.conv1(x))) + if self.deep_base: + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + +def resnet18(pretrained=False, **kwargs): + """Constructs a ResNet-18 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) + return model + + +def resnet34(pretrained=False, **kwargs): + """Constructs a ResNet-34 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) + return model + + +def resnet50(pretrained=True, **kwargs): + """Constructs a ResNet-50 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + # model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) + model_path = './initmodel/resnet50_v2.pth' + model.load_state_dict(torch.load(model_path), strict=False) + return model + + +def resnet101(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: + # model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) + model_path = './initmodel/resnet101_v2.pth' + model.load_state_dict(torch.load(model_path), strict=False) + return model + + +def resnet152(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + if pretrained: + # model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) + model_path = './initmodel/resnet152_v2.pth' + model.load_state_dict(torch.load(model_path), strict=False) + return model diff --git a/model/vgg.py b/model/vgg.py new file mode 100755 index 0000000..1d8e5a0 --- /dev/null +++ b/model/vgg.py @@ -0,0 +1,246 @@ +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo + +BatchNorm = nn.BatchNorm2d + +__all__ = [ + 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', + 'vgg19_bn', 'vgg19', +] + + +model_urls = { + 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', + 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', + 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', + 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', + 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', + 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', + 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', + 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', +} + + +class VGG(nn.Module): + + def __init__(self, features, num_classes=1000, init_weights=True): + super(VGG, self).__init__() + self.features = features + self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) + self.classifier = nn.Sequential( + nn.Linear(512 * 7 * 7, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, num_classes), + ) + if init_weights: + self._initialize_weights() + + def forward(self, x): + x = self.features(x) + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.classifier(x) + return x + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, BatchNorm): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + +def make_layers(cfg, batch_norm=False): + layers = [] + in_channels = 3 + for v in cfg: + if v == 'M': + layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)] + else: + conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) + if batch_norm: + layers += [conv2d, BatchNorm(v), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = v + return nn.Sequential(*layers) + + +cfg = { + 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], + 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], +} + + +def 
vgg11(pretrained=False, **kwargs): + """VGG 11-layer model (configuration "A") + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfg['A']), **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['vgg11'])) + return model + + +def vgg11_bn(pretrained=False, **kwargs): + """VGG 11-layer model (configuration "A") with batch normalization + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn'])) + return model + + +def vgg13(pretrained=False, **kwargs): + """VGG 13-layer model (configuration "B") + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfg['B']), **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['vgg13'])) + return model + + +def vgg13_bn(pretrained=False, **kwargs): + """VGG 13-layer model (configuration "B") with batch normalization + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn'])) + return model + + +def vgg16(pretrained=False, **kwargs): + """VGG 16-layer model (configuration "D") + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfg['D']), **kwargs) + if pretrained: + #model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) + model_path = './initmodel/vgg16.pth' + model.load_state_dict(torch.load(model_path), strict=False) + return model + + +def vgg16_bn(pretrained=False, **kwargs): + """VGG 16-layer model (configuration "D") with batch normalization + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) + if pretrained: + #model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) + model_path = './initmodel/vgg16_bn.pth' + model.load_state_dict(torch.load(model_path), strict=False) + return model + + +def vgg19(pretrained=False, **kwargs): + """VGG 19-layer model (configuration "E") + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfg['E']), **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) + return model + + +def vgg19_bn(pretrained=False, **kwargs): + """VGG 19-layer model (configuration 'E') with batch normalization + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn'])) + return model + +if __name__ =='__main__': + import os + os.environ["CUDA_VISIBLE_DEVICES"] = '7' + input = torch.rand(4, 3, 473, 473).cuda() + target = torch.rand(4, 473, 473).cuda()*1.0 + model = vgg16_bn(pretrained=False).cuda() + model.train() 
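+    # Standalone sanity check (only runs when this file is executed directly): slice vgg16_bn.features into five sequential stages and print the feature-map size after each stage for a 473x473 input.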
+ layer0_idx = range(0,6) + layer1_idx = range(6,13) + layer2_idx = range(13,23) + layer3_idx = range(23,33) + layer4_idx = range(34,43) + #layer4_idx = range(34,43) + print(model.features) + layers_0 = [] + layers_1 = [] + layers_2 = [] + layers_3 = [] + layers_4 = [] + for idx in layer0_idx: + layers_0 += [model.features[idx]] + for idx in layer1_idx: + layers_1 += [model.features[idx]] + for idx in layer2_idx: + layers_2 += [model.features[idx]] + for idx in layer3_idx: + layers_3 += [model.features[idx]] + for idx in layer4_idx: + layers_4 += [model.features[idx]] + + layer0 = nn.Sequential(*layers_0) + layer1 = nn.Sequential(*layers_1) + layer2 = nn.Sequential(*layers_2) + layer3 = nn.Sequential(*layers_3) + layer4 = nn.Sequential(*layers_4) + + output = layer0(input) + print(layer0) + print('layer 0: {}'.format(output.size())) + output = layer1(output) + print(layer1) + print('layer 1: {}'.format(output.size())) + output = layer2(output) + print(layer2) + print('layer 2: {}'.format(output.size())) + output = layer3(output) + print(layer3) + print('layer 3: {}'.format(output.size())) + output = layer4(output) + print(layer4) + print('layer 4: {}'.format(output.size())) + \ No newline at end of file diff --git a/test.py b/test.py new file mode 100755 index 0000000..5a13e15 --- /dev/null +++ b/test.py @@ -0,0 +1,263 @@ +import os +import random +import time +import cv2 +import numpy as np +import logging +import argparse + +import torch +import torch.backends.cudnn as cudnn +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.multiprocessing as mp +import torch.distributed as dist +from tensorboardX import SummaryWriter + +from model.PFENet import PFENet +from util import dataset +from util import transform, config +from util.util import AverageMeter, poly_learning_rate, intersectionAndUnionGPU + +cv2.ocl.setUseOpenCL(False) +cv2.setNumThreads(0) + + +def get_parser(): + parser = argparse.ArgumentParser(description='PyTorch Semantic Segmentation') + parser.add_argument('--config', type=str, default='config/ade20k/ade20k_pspnet50.yaml', help='config file') + parser.add_argument('opts', help='see config/ade20k/ade20k_pspnet50.yaml for all options', default=None, nargs=argparse.REMAINDER) + args = parser.parse_args() + assert args.config is not None + cfg = config.load_cfg_from_cfg_file(args.config) + if args.opts is not None: + cfg = config.merge_cfg_from_list(cfg, args.opts) + return cfg + + +def get_logger(): + logger_name = "main-logger" + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + handler = logging.StreamHandler() + fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s" + handler.setFormatter(logging.Formatter(fmt)) + logger.addHandler(handler) + return logger + + +def worker_init_fn(worker_id): + random.seed(args.manual_seed + worker_id) + + +def main_process(): + return not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % args.ngpus_per_node == 0) + + +def main(): + args = get_parser() + assert args.classes > 1 + assert args.zoom_factor in [1, 2, 4, 8] + assert (args.train_h - 1) % 8 == 0 and (args.train_w - 1) % 8 == 0 + os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.train_gpu) + if args.manual_seed is not None: + cudnn.benchmark = False + cudnn.deterministic = True + torch.cuda.manual_seed(args.manual_seed) + np.random.seed(args.manual_seed) + 
torch.manual_seed(args.manual_seed) + torch.cuda.manual_seed_all(args.manual_seed) + random.seed(args.manual_seed) + if args.dist_url == "env://" and args.world_size == -1: + args.world_size = int(os.environ["WORLD_SIZE"]) + args.distributed = args.world_size > 1 or args.multiprocessing_distributed + args.ngpus_per_node = len(args.train_gpu) + if len(args.train_gpu) == 1: + args.sync_bn = False + args.distributed = False + args.multiprocessing_distributed = False + if args.multiprocessing_distributed: + args.world_size = args.ngpus_per_node * args.world_size + mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args.ngpus_per_node, args)) + else: + main_worker(args.train_gpu, args.ngpus_per_node, args) + + +def main_worker(gpu, ngpus_per_node, argss): + global args + args = argss + + BatchNorm = nn.BatchNorm2d + + criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label) + + model = PFENet(layers=args.layers, classes=2, zoom_factor=8, \ + criterion=nn.CrossEntropyLoss(ignore_index=255), BatchNorm=BatchNorm, \ + pretrained=True, shot=args.shot, ppm_scales=args.ppm_scales, vgg=args.vgg) + + global logger, writer + logger = get_logger() + writer = SummaryWriter(args.save_path) + logger.info("=> creating model ...") + logger.info("Classes: {}".format(args.classes)) + logger.info(model) + print(args) + + model = torch.nn.DataParallel(model.cuda()) + + if args.weight: + if os.path.isfile(args.weight): + logger.info("=> loading weight '{}'".format(args.weight)) + checkpoint = torch.load(args.weight) + model.load_state_dict(checkpoint['state_dict']) + logger.info("=> loaded weight '{}'".format(args.weight)) + else: + logger.info("=> no weight found at '{}'".format(args.weight)) + + value_scale = 255 + mean = [0.485, 0.456, 0.406] + mean = [item * value_scale for item in mean] + std = [0.229, 0.224, 0.225] + std = [item * value_scale for item in std] + + assert args.split in [0, 1, 2, 3, 999] + + if args.resized_val: + val_transform = transform.Compose([ + transform.Resize(size=args.val_size), + transform.ToTensor(), + transform.Normalize(mean=mean, std=std)]) + else: + val_transform = transform.Compose([ + transform.test_Resize(size=args.val_size), + transform.ToTensor(), + transform.Normalize(mean=mean, std=std)]) + val_data = dataset.SemData(split=args.split, shot=args.shot, data_root=args.data_root, \ + data_list=args.val_list, transform=val_transform, mode='val', \ + use_coco=args.use_coco, use_split_coco=args.use_split_coco) + val_sampler = None + val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, shuffle=False, num_workers=args.workers, pin_memory=True, sampler=val_sampler) + + loss_val, mIoU_val, mAcc_val, allAcc_val, class_miou = validate(val_loader, model, criterion) + +def validate(val_loader, model, criterion): + if main_process(): + logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>') + batch_time = AverageMeter() + model_time = AverageMeter() + data_time = AverageMeter() + loss_meter = AverageMeter() + intersection_meter = AverageMeter() + union_meter = AverageMeter() + target_meter = AverageMeter() + if args.use_coco: + split_gap = 20 + else: + split_gap = 5 + class_intersection_meter = [0]*split_gap + class_union_meter = [0]*split_gap + + if args.manual_seed is not None and args.fix_random_seed_val: + torch.cuda.manual_seed(args.manual_seed) + np.random.seed(args.manual_seed) + torch.manual_seed(args.manual_seed) + torch.cuda.manual_seed_all(args.manual_seed) + random.seed(args.manual_seed) + + model.eval() + end = time.time() + 
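+    # Episodic evaluation: the val loader is iterated repeatedly (outer loop below) until a fixed number of support/query episodes has been scored (20000 for COCO, 5000 for PASCAL-5i); split 999 instead uses len(val_loader) episodes.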
if args.split != 999: + if args.use_coco: + test_num = 20000 + else: + test_num = 5000 + else: + test_num = len(val_loader) + assert test_num % args.batch_size_val == 0 + iter_num = 0 + total_time = 0 + for e in range(20): + for i, (input, target, s_input, s_mask, subcls, ori_label) in enumerate(val_loader): + if (iter_num-1) * args.batch_size_val >= test_num: + break + iter_num += 1 + data_time.update(time.time() - end) + input = input.cuda(non_blocking=True) + target = target.cuda(non_blocking=True) + ori_label = ori_label.cuda(non_blocking=True) + start_time = time.time() + output = model(s_x=s_input, s_y=s_mask, x=input, y=target) + total_time = total_time + 1 + model_time.update(time.time() - start_time) + + if args.ori_resize: + longerside = max(ori_label.size(1), ori_label.size(2)) + backmask = torch.ones(ori_label.size(0), longerside, longerside).cuda()*255 + backmask[0, :ori_label.size(1), :ori_label.size(2)] = ori_label + target = backmask.clone().long() + + output = F.interpolate(output, size=target.size()[1:], mode='bilinear', align_corners=True) + loss = criterion(output, target) + + n = input.size(0) + loss = torch.mean(loss) + + output = output.max(1)[1] + + intersection, union, new_target = intersectionAndUnionGPU(output, target, args.classes, args.ignore_label) + intersection, union, target, new_target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy(), new_target.cpu().numpy() + intersection_meter.update(intersection), union_meter.update(union), target_meter.update(new_target) + + subcls = subcls[0].cpu().numpy()[0] + class_intersection_meter[(subcls-1)%split_gap] += intersection[1] + class_union_meter[(subcls-1)%split_gap] += union[1] + + accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10) + loss_meter.update(loss.item(), input.size(0)) + batch_time.update(time.time() - end) + end = time.time() + if ((i + 1) % (test_num/100) == 0) and main_process(): + logger.info('Test: [{}/{}] ' + 'Data {data_time.val:.3f} ({data_time.avg:.3f}) ' + 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) ' + 'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f}) ' + 'Accuracy {accuracy:.4f}.'.format(iter_num* args.batch_size_val, test_num, + data_time=data_time, + batch_time=batch_time, + loss_meter=loss_meter, + accuracy=accuracy)) + + iou_class = intersection_meter.sum / (union_meter.sum + 1e-10) + accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10) + mIoU = np.mean(iou_class) + mAcc = np.mean(accuracy_class) + allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10) + + + class_iou_class = [] + class_miou = 0 + for i in range(len(class_intersection_meter)): + class_iou = class_intersection_meter[i]/(class_union_meter[i]+ 1e-10) + class_iou_class.append(class_iou) + class_miou += class_iou + class_miou = class_miou*1.0 / len(class_intersection_meter) + logger.info('meanIoU---Val result: mIoU {:.4f}.'.format(class_miou)) + for i in range(split_gap): + logger.info('Class_{} Result: iou {:.4f}.'.format(i+1, class_iou_class[i])) + + + if main_process(): + logger.info('FBIoU---Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(mIoU, mAcc, allAcc)) + for i in range(args.classes): + logger.info('Class_{} Result: iou/accuracy {:.4f}/{:.4f}.'.format(i, iou_class[i], accuracy_class[i])) + logger.info('<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<') + + print('avg inference time: {:.4f}, count: {}'.format(model_time.avg, test_num)) + return loss_meter.avg, mIoU, mAcc, allAcc, class_miou + + +if __name__ == '__main__': + 
main() diff --git a/test.sh b/test.sh new file mode 100755 index 0000000..d949fdb --- /dev/null +++ b/test.sh @@ -0,0 +1,15 @@ +#!/bin/sh +PARTITION=Segmentation + +dataset=$1 +exp_name=$2 +exp_dir=exp/${dataset}/${exp_name} +model_dir=${exp_dir}/model +result_dir=${exp_dir}/result +config=config/${dataset}/${dataset}_${exp_name}.yaml + +mkdir -p ${model_dir} ${result_dir} +now=$(date +"%Y%m%d_%H%M%S") +cp test.sh test.py ${config} ${exp_dir} + +python3 -u test.py --config=${config} 2>&1 | tee ${result_dir}/test-$now.log diff --git a/train.py b/train.py new file mode 100755 index 0000000..dda9a38 --- /dev/null +++ b/train.py @@ -0,0 +1,452 @@ +import os +import random +import time +import cv2 +import numpy as np +import logging +import argparse + +import torch +import torch.backends.cudnn as cudnn +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.multiprocessing as mp +import torch.distributed as dist +from tensorboardX import SummaryWriter + +from model.PFENet import PFENet +from util import dataset +from util import transform, config +from util.util import AverageMeter, poly_learning_rate, intersectionAndUnionGPU + +cv2.ocl.setUseOpenCL(False) +cv2.setNumThreads(0) + + +def get_parser(): + parser = argparse.ArgumentParser(description='PyTorch Semantic Segmentation') + parser.add_argument('--config', type=str, default='config/ade20k/ade20k_pspnet50.yaml', help='config file') + parser.add_argument('opts', help='see config/ade20k/ade20k_pspnet50.yaml for all options', default=None, nargs=argparse.REMAINDER) + args = parser.parse_args() + assert args.config is not None + cfg = config.load_cfg_from_cfg_file(args.config) + if args.opts is not None: + cfg = config.merge_cfg_from_list(cfg, args.opts) + return cfg + + +def get_logger(): + logger_name = "main-logger" + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + handler = logging.StreamHandler() + fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s" + handler.setFormatter(logging.Formatter(fmt)) + logger.addHandler(handler) + return logger + + +def worker_init_fn(worker_id): + random.seed(args.manual_seed + worker_id) + + +def main_process(): + return not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % args.ngpus_per_node == 0) + + +def main(): + args = get_parser() + assert args.classes > 1 + assert args.zoom_factor in [1, 2, 4, 8] + assert (args.train_h - 1) % 8 == 0 and (args.train_w - 1) % 8 == 0 + os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.train_gpu) + if args.manual_seed is not None: + cudnn.benchmark = False + cudnn.deterministic = True + torch.cuda.manual_seed(args.manual_seed) + np.random.seed(args.manual_seed) + torch.manual_seed(args.manual_seed) + torch.cuda.manual_seed_all(args.manual_seed) + random.seed(args.manual_seed) + + ### multi-processing training is deprecated + if args.dist_url == "env://" and args.world_size == -1: + args.world_size = int(os.environ["WORLD_SIZE"]) + args.distributed = args.world_size > 1 or args.multiprocessing_distributed + args.ngpus_per_node = len(args.train_gpu) + if len(args.train_gpu) == 1: + args.sync_bn = False # sync_bn is deprecated + args.distributed = False + args.multiprocessing_distributed = False + if args.multiprocessing_distributed: + args.world_size = args.ngpus_per_node * args.world_size + mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args.ngpus_per_node, args)) + 
else: + main_worker(args.train_gpu, args.ngpus_per_node, args) + + +def main_worker(gpu, ngpus_per_node, argss): + global args + args = argss + + BatchNorm = nn.BatchNorm2d + + criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label) + + model = PFENet(layers=args.layers, classes=2, zoom_factor=8, \ + criterion=nn.CrossEntropyLoss(ignore_index=255), BatchNorm=BatchNorm, \ + pretrained=True, shot=args.shot, ppm_scales=args.ppm_scales, vgg=args.vgg) + + for param in model.layer0.parameters(): + param.requires_grad = False + for param in model.layer1.parameters(): + param.requires_grad = False + for param in model.layer2.parameters(): + param.requires_grad = False + for param in model.layer3.parameters(): + param.requires_grad = False + for param in model.layer4.parameters(): + param.requires_grad = False + + optimizer = torch.optim.SGD( + [ + {'params': model.down_query.parameters()}, + {'params': model.down_supp.parameters()}, + {'params': model.init_merge.parameters()}, + {'params': model.alpha_conv.parameters()}, + {'params': model.beta_conv.parameters()}, + {'params': model.inner_cls.parameters()}, + {'params': model.res1.parameters()}, + {'params': model.res2.parameters()}, + {'params': model.cls.parameters()}], + lr=args.base_lr, momentum=args.momentum, weight_decay=args.weight_decay) + + global logger, writer + logger = get_logger() + writer = SummaryWriter(args.save_path) + logger.info("=> creating model ...") + logger.info("Classes: {}".format(args.classes)) + logger.info(model) + print(args) + + model = torch.nn.DataParallel(model.cuda()) + + if args.weight: + if os.path.isfile(args.weight): + logger.info("=> loading weight '{}'".format(args.weight)) + checkpoint = torch.load(args.weight) + model.load_state_dict(checkpoint['state_dict']) + logger.info("=> loaded weight '{}'".format(args.weight)) + else: + logger.info("=> no weight found at '{}'".format(args.weight)) + + if args.resume: + if os.path.isfile(args.resume): + logger.info("=> loading checkpoint '{}'".format(args.resume)) + checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda()) + args.start_epoch = checkpoint['epoch'] + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) + else: + logger.info("=> no checkpoint found at '{}'".format(args.resume)) + + + value_scale = 255 + mean = [0.485, 0.456, 0.406] + mean = [item * value_scale for item in mean] + std = [0.229, 0.224, 0.225] + std = [item * value_scale for item in std] + + assert args.split in [0, 1, 2, 3, 999] + train_transform = [ + transform.RandScale([args.scale_min, args.scale_max]), + transform.RandRotate([args.rotate_min, args.rotate_max], padding=mean, ignore_label=args.padding_label), + transform.RandomGaussianBlur(), + transform.RandomHorizontalFlip(), + transform.Crop([args.train_h, args.train_w], crop_type='rand', padding=mean, ignore_label=args.padding_label), + transform.ToTensor(), + transform.Normalize(mean=mean, std=std)] + train_transform = transform.Compose(train_transform) + train_data = dataset.SemData(split=args.split, shot=args.shot, data_root=args.data_root, \ + data_list=args.train_list, transform=train_transform, mode='train', \ + use_coco=args.use_coco, use_split_coco=args.use_split_coco) + + train_sampler = None + train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=(train_sampler is None), num_workers=args.workers, 
pin_memory=True, sampler=train_sampler, drop_last=True) + if args.evaluate: + if args.resized_val: + val_transform = transform.Compose([ + transform.Resize(size=args.val_size), + transform.ToTensor(), + transform.Normalize(mean=mean, std=std)]) + else: + val_transform = transform.Compose([ + transform.test_Resize(size=args.val_size), + transform.ToTensor(), + transform.Normalize(mean=mean, std=std)]) + val_data = dataset.SemData(split=args.split, shot=args.shot, data_root=args.data_root, \ + data_list=args.val_list, transform=val_transform, mode='val', \ + use_coco=args.use_coco, use_split_coco=args.use_split_coco) + val_sampler = None + val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, shuffle=False, num_workers=args.workers, pin_memory=True, sampler=val_sampler) + + max_iou = 0. + filename = 'PFENet.pth' + + for epoch in range(args.start_epoch, args.epochs): + if args.fix_random_seed_val: + torch.cuda.manual_seed(args.manual_seed + epoch) + np.random.seed(args.manual_seed + epoch) + torch.manual_seed(args.manual_seed + epoch) + torch.cuda.manual_seed_all(args.manual_seed + epoch) + random.seed(args.manual_seed + epoch) + + epoch_log = epoch + 1 + loss_train, mIoU_train, mAcc_train, allAcc_train = train(train_loader, model, optimizer, epoch) + if main_process(): + writer.add_scalar('loss_train', loss_train, epoch_log) + writer.add_scalar('mIoU_train', mIoU_train, epoch_log) + writer.add_scalar('mAcc_train', mAcc_train, epoch_log) + writer.add_scalar('allAcc_train', allAcc_train, epoch_log) + + if args.evaluate and (epoch % 2 == 0 or (args.epochs<=50 and epoch%1==0)): + loss_val, mIoU_val, mAcc_val, allAcc_val, class_miou = validate(val_loader, model, criterion) + if main_process(): + writer.add_scalar('loss_val', loss_val, epoch_log) + writer.add_scalar('mIoU_val', mIoU_val, epoch_log) + writer.add_scalar('mAcc_val', mAcc_val, epoch_log) + writer.add_scalar('class_miou_val', class_miou, epoch_log) + writer.add_scalar('allAcc_val', allAcc_val, epoch_log) + if class_miou > max_iou: + max_iou = class_miou + if os.path.exists(filename): + os.remove(filename) + filename = args.save_path + '/train_epoch_' + str(epoch) + '_'+str(max_iou)+'.pth' + logger.info('Saving checkpoint to: ' + filename) + torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, filename) + + filename = args.save_path + '/final.pth' + logger.info('Saving checkpoint to: ' + filename) + torch.save({'epoch': args.epochs, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, filename) + + +def train(train_loader, model, optimizer, epoch): + batch_time = AverageMeter() + data_time = AverageMeter() + main_loss_meter = AverageMeter() + aux_loss_meter = AverageMeter() + loss_meter = AverageMeter() + intersection_meter = AverageMeter() + union_meter = AverageMeter() + target_meter = AverageMeter() + + model.train() + end = time.time() + max_iter = args.epochs * len(train_loader) + vis_key = 0 + print('Warmup: {}'.format(args.warmup)) + for i, (input, target, s_input, s_mask, subcls) in enumerate(train_loader): + data_time.update(time.time() - end) + current_iter = epoch * len(train_loader) + i + 1 + index_split = -1 + if args.base_lr > 1e-6: + poly_learning_rate(optimizer, args.base_lr, current_iter, max_iter, power=args.power, index_split=index_split, warmup=args.warmup, warmup_step=len(train_loader)//2) + + s_input = s_input.cuda(non_blocking=True) + s_mask = s_mask.cuda(non_blocking=True) + input = input.cuda(non_blocking=True) + 
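+        # each sample in the batch is one episode: s_input/s_mask are the support shots and their masks, input/target the query image and its ground-truth mask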
target = target.cuda(non_blocking=True) + + output, main_loss, aux_loss = model(s_x=s_input, s_y=s_mask, x=input, y=target) + + if not args.multiprocessing_distributed: + main_loss, aux_loss = torch.mean(main_loss), torch.mean(aux_loss) + loss = main_loss + args.aux_weight * aux_loss + optimizer.zero_grad() + + loss.backward() + optimizer.step() + n = input.size(0) + if args.multiprocessing_distributed: + main_loss, aux_loss, loss = main_loss.detach() * n, aux_loss * n, loss * n + count = target.new_tensor([n], dtype=torch.long) + dist.all_reduce(main_loss), dist.all_reduce(aux_loss), dist.all_reduce(loss), dist.all_reduce(count) + n = count.item() + main_loss, aux_loss, loss = main_loss / n, aux_loss / n, loss / n + + intersection, union, target = intersectionAndUnionGPU(output, target, args.classes, args.ignore_label) + if args.multiprocessing_distributed: + dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(target) + intersection, union, target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy() + intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target) + + accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10) + main_loss_meter.update(main_loss.item(), n) + aux_loss_meter.update(aux_loss.item(), n) + loss_meter.update(loss.item(), n) + batch_time.update(time.time() - end) + end = time.time() + + remain_iter = max_iter - current_iter + remain_time = remain_iter * batch_time.avg + t_m, t_s = divmod(remain_time, 60) + t_h, t_m = divmod(t_m, 60) + remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s)) + + if (i + 1) % args.print_freq == 0 and main_process(): + logger.info('Epoch: [{}/{}][{}/{}] ' + 'Data {data_time.val:.3f} ({data_time.avg:.3f}) ' + 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) ' + 'Remain {remain_time} ' + 'MainLoss {main_loss_meter.val:.4f} ' + 'AuxLoss {aux_loss_meter.val:.4f} ' + 'Loss {loss_meter.val:.4f} ' + 'Accuracy {accuracy:.4f}.'.format(epoch+1, args.epochs, i + 1, len(train_loader), + batch_time=batch_time, + data_time=data_time, + remain_time=remain_time, + main_loss_meter=main_loss_meter, + aux_loss_meter=aux_loss_meter, + loss_meter=loss_meter, + accuracy=accuracy)) + if main_process(): + writer.add_scalar('loss_train_batch', main_loss_meter.val, current_iter) + writer.add_scalar('mIoU_train_batch', np.mean(intersection / (union + 1e-10)), current_iter) + writer.add_scalar('mAcc_train_batch', np.mean(intersection / (target + 1e-10)), current_iter) + writer.add_scalar('allAcc_train_batch', accuracy, current_iter) + + iou_class = intersection_meter.sum / (union_meter.sum + 1e-10) + accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10) + mIoU = np.mean(iou_class) + mAcc = np.mean(accuracy_class) + allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10) + + if main_process(): + logger.info('Train result at epoch [{}/{}]: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(epoch, args.epochs, mIoU, mAcc, allAcc)) + for i in range(args.classes): + logger.info('Class_{} Result: iou/accuracy {:.4f}/{:.4f}.'.format(i, iou_class[i], accuracy_class[i])) + return main_loss_meter.avg, mIoU, mAcc, allAcc + + +def validate(val_loader, model, criterion): + if main_process(): + logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>') + batch_time = AverageMeter() + model_time = AverageMeter() + data_time = AverageMeter() + loss_meter = AverageMeter() + intersection_meter = AverageMeter() + union_meter = AverageMeter() + 
target_meter = AverageMeter() + if args.use_coco: + split_gap = 20 + else: + split_gap = 5 + class_intersection_meter = [0]*split_gap + class_union_meter = [0]*split_gap + + if args.manual_seed is not None and args.fix_random_seed_val: + torch.cuda.manual_seed(args.manual_seed) + np.random.seed(args.manual_seed) + torch.manual_seed(args.manual_seed) + torch.cuda.manual_seed_all(args.manual_seed) + random.seed(args.manual_seed) + + model.eval() + end = time.time() + if args.split != 999: + if args.use_coco: + test_num = 20000 + else: + test_num = 5000 + else: + test_num = len(val_loader) + assert test_num % args.batch_size_val == 0 + iter_num = 0 + total_time = 0 + for e in range(10): + for i, (input, target, s_input, s_mask, subcls, ori_label) in enumerate(val_loader): + if (iter_num-1) * args.batch_size_val >= test_num: + break + iter_num += 1 + data_time.update(time.time() - end) + input = input.cuda(non_blocking=True) + target = target.cuda(non_blocking=True) + ori_label = ori_label.cuda(non_blocking=True) + start_time = time.time() + output = model(s_x=s_input, s_y=s_mask, x=input, y=target) + total_time = total_time + 1 + model_time.update(time.time() - start_time) + + if args.ori_resize: + longerside = max(ori_label.size(1), ori_label.size(2)) + backmask = torch.ones(ori_label.size(0), longerside, longerside).cuda()*255 + backmask[0, :ori_label.size(1), :ori_label.size(2)] = ori_label + target = backmask.clone().long() + + output = F.interpolate(output, size=target.size()[1:], mode='bilinear', align_corners=True) + loss = criterion(output, target) + + n = input.size(0) + loss = torch.mean(loss) + + output = output.max(1)[1] + + intersection, union, new_target = intersectionAndUnionGPU(output, target, args.classes, args.ignore_label) + intersection, union, target, new_target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy(), new_target.cpu().numpy() + intersection_meter.update(intersection), union_meter.update(union), target_meter.update(new_target) + + subcls = subcls[0].cpu().numpy()[0] + class_intersection_meter[(subcls-1)%split_gap] += intersection[1] + class_union_meter[(subcls-1)%split_gap] += union[1] + + accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10) + loss_meter.update(loss.item(), input.size(0)) + batch_time.update(time.time() - end) + end = time.time() + if ((i + 1) % (test_num/100) == 0) and main_process(): + logger.info('Test: [{}/{}] ' + 'Data {data_time.val:.3f} ({data_time.avg:.3f}) ' + 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) ' + 'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f}) ' + 'Accuracy {accuracy:.4f}.'.format(iter_num* args.batch_size_val, test_num, + data_time=data_time, + batch_time=batch_time, + loss_meter=loss_meter, + accuracy=accuracy)) + + iou_class = intersection_meter.sum / (union_meter.sum + 1e-10) + accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10) + mIoU = np.mean(iou_class) + mAcc = np.mean(accuracy_class) + allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10) + + + class_iou_class = [] + class_miou = 0 + for i in range(len(class_intersection_meter)): + class_iou = class_intersection_meter[i]/(class_union_meter[i]+ 1e-10) + class_iou_class.append(class_iou) + class_miou += class_iou + class_miou = class_miou*1.0 / len(class_intersection_meter) + logger.info('meanIoU---Val result: mIoU {:.4f}.'.format(class_miou)) + for i in range(split_gap): + logger.info('Class_{} Result: iou {:.4f}.'.format(i+1, class_iou_class[i])) + + + if main_process(): + 
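+        # FB-IoU (the IoU averaged over the foreground and background classes) is logged by the
+        # main process only; the class-wise mIoU over the novel classes was reported just above.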
logger.info('FBIoU---Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(mIoU, mAcc, allAcc)) + for i in range(args.classes): + logger.info('Class_{} Result: iou/accuracy {:.4f}/{:.4f}.'.format(i, iou_class[i], accuracy_class[i])) + logger.info('<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<') + + print('avg inference time: {:.4f}, count: {}'.format(model_time.avg, test_num)) + return loss_meter.avg, mIoU, mAcc, allAcc, class_miou + + +if __name__ == '__main__': + main() diff --git a/train.sh b/train.sh new file mode 100755 index 0000000..9db0e4b --- /dev/null +++ b/train.sh @@ -0,0 +1,15 @@ +#!/bin/sh +PARTITION=Segmentation + +dataset=$1 +exp_name=$2 +exp_dir=exp/${dataset}/${exp_name} +model_dir=${exp_dir}/model +result_dir=${exp_dir}/result +config=config/${dataset}/${dataset}_${exp_name}.yaml + +mkdir -p ${model_dir} ${result_dir} +now=$(date +"%Y%m%d_%H%M%S") +cp train.sh train.py ${config} ${exp_dir} + +python3 -u train.py --config=${config} 2>&1 | tee ${result_dir}/train-$now.log diff --git a/util/config.py b/util/config.py new file mode 100755 index 0000000..0aa7b11 --- /dev/null +++ b/util/config.py @@ -0,0 +1,166 @@ +# ----------------------------------------------------------------------------- +# Functions for parsing args +# ----------------------------------------------------------------------------- +import yaml +import os +from ast import literal_eval +import copy + + +class CfgNode(dict): + """ + CfgNode represents an internal node in the configuration tree. It's a simple + dict-like container that allows for attribute-based access to keys. + """ + + def __init__(self, init_dict=None, key_list=None, new_allowed=False): + # Recursively convert nested dictionaries in init_dict into CfgNodes + init_dict = {} if init_dict is None else init_dict + key_list = [] if key_list is None else key_list + for k, v in init_dict.items(): + if type(v) is dict: + # Convert dict to CfgNode + init_dict[k] = CfgNode(v, key_list=key_list + [k]) + super(CfgNode, self).__init__(init_dict) + + def __getattr__(self, name): + if name in self: + return self[name] + else: + raise AttributeError(name) + + def __setattr__(self, name, value): + self[name] = value + + def __str__(self): + def _indent(s_, num_spaces): + s = s_.split("\n") + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(num_spaces * " ") + line for line in s] + s = "\n".join(s) + s = first + "\n" + s + return s + + r = "" + s = [] + for k, v in sorted(self.items()): + seperator = "\n" if isinstance(v, CfgNode) else " " + attr_str = "{}:{}{}".format(str(k), seperator, str(v)) + attr_str = _indent(attr_str, 2) + s.append(attr_str) + r += "\n".join(s) + return r + + def __repr__(self): + return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__()) + + +def load_cfg_from_cfg_file(file): + cfg = {} + assert os.path.isfile(file) and file.endswith('.yaml'), \ + '{} is not a yaml file'.format(file) + + with open(file, 'r') as f: + cfg_from_file = yaml.safe_load(f) + + for key in cfg_from_file: + for k, v in cfg_from_file[key].items(): + cfg[k] = v + + cfg = CfgNode(cfg) + return cfg + + +def merge_cfg_from_list(cfg, cfg_list): + new_cfg = copy.deepcopy(cfg) + assert len(cfg_list) % 2 == 0 + for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]): + subkey = full_key.split('.')[-1] + assert subkey in cfg, 'Non-existent key: {}'.format(full_key) + value = _decode_cfg_value(v) + value = _check_and_coerce_cfg_value_type( + value, cfg[subkey], subkey, full_key + ) + setattr(new_cfg, subkey, value) + + return 
new_cfg + + +def _decode_cfg_value(v): + """Decodes a raw config value (e.g., from a yaml config files or command + line argument) into a Python object. + """ + # All remaining processing is only applied to strings + if not isinstance(v, str): + return v + # Try to interpret `v` as a: + # string, number, tuple, list, dict, boolean, or None + try: + v = literal_eval(v) + # The following two excepts allow v to pass through when it represents a + # string. + # + # Longer explanation: + # The type of v is always a string (before calling literal_eval), but + # sometimes it *represents* a string and other times a data structure, like + # a list. In the case that v represents a string, what we got back from the + # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is + # ok with '"foo"', but will raise a ValueError if given 'foo'. In other + # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval + # will raise a SyntaxError. + except ValueError: + pass + except SyntaxError: + pass + return v + + +def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): + """Checks that `replacement`, which is intended to replace `original` is of + the right type. The type is correct if it matches exactly or is one of a few + cases in which the type can be easily coerced. + """ + original_type = type(original) + replacement_type = type(replacement) + + # The types must match (with some exceptions) + if replacement_type == original_type: + return replacement + + # Cast replacement from from_type to to_type if the replacement and original + # types match from_type and to_type + def conditional_cast(from_type, to_type): + if replacement_type == from_type and original_type == to_type: + return True, to_type(replacement) + else: + return False, None + + # Conditionally casts + # list <-> tuple + casts = [(tuple, list), (list, tuple)] + # For py2: allow converting from str (bytes) to a unicode string + try: + casts.append((str, unicode)) # noqa: F821 + except Exception: + pass + + for (from_type, to_type) in casts: + converted, converted_value = conditional_cast(from_type, to_type) + if converted: + return converted_value + + raise ValueError( + "Type mismatch ({} vs. {}) with values ({} vs. 
{}) for config "
+        "key: {}".format(
+            original_type, replacement_type, original, replacement, full_key
+        )
+    )
+
+
+def _assert_with_logging(cond, msg):
+    if not cond:
+        # No module-level logger is defined in this file, so print the message before asserting.
+        print(msg)
+    assert cond, msg
+
diff --git a/util/dataset.py b/util/dataset.py
new file mode 100755
index 0000000..d9c2898
--- /dev/null
+++ b/util/dataset.py
@@ -0,0 +1,245 @@
+import os
+import os.path
+import cv2
+import numpy as np
+
+from torch.utils.data import Dataset
+import torch.nn.functional as F
+import torch
+import random
+import time
+from tqdm import tqdm
+
+IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']
+
+
+def is_image_file(filename):
+    filename_lower = filename.lower()
+    return any(filename_lower.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def make_dataset(split=0, data_root=None, data_list=None, sub_list=None):
+    assert split in [0, 1, 2, 3, 10, 11, 999]
+    if not os.path.isfile(data_list):
+        raise (RuntimeError("Image list file does not exist: " + data_list + "\n"))
+
+    # Shaban et al. use these lines to remove small objects:
+    #     if util.change_coordinates(mask, 32.0, 0.0).sum() > 2:
+    #         filtered_item.append(item)
+    # i.e. the mask is downsampled to 1/32 of the original size and the valid area must be larger
+    # than 2 pixels, so the area at the original size must be larger than 2 * 32 * 32.
+    image_label_list = []
+    list_read = open(data_list).readlines()
+    print("Processing data for classes {}...".format(sub_list))
+    sub_class_file_list = {}
+    for sub_c in sub_list:
+        sub_class_file_list[sub_c] = []
+
+    for l_idx in tqdm(range(len(list_read))):
+        line = list_read[l_idx]
+        line = line.strip()
+        line_split = line.split(' ')
+        image_name = os.path.join(data_root, line_split[0])
+        label_name = os.path.join(data_root, line_split[1])
+        item = (image_name, label_name)
+        label = cv2.imread(label_name, cv2.IMREAD_GRAYSCALE)
+        label_class = np.unique(label).tolist()
+
+        if 0 in label_class:
+            label_class.remove(0)
+        if 255 in label_class:
+            label_class.remove(255)
+
+        new_label_class = []
+        for c in label_class:
+            if c in sub_list:
+                tmp_label = np.zeros_like(label)
+                target_pix = np.where(label == c)
+                tmp_label[target_pix[0], target_pix[1]] = 1
+                if tmp_label.sum() >= 2 * 32 * 32:
+                    new_label_class.append(c)
+
+        label_class = new_label_class
+
+        if len(label_class) > 0:
+            image_label_list.append(item)
+            for c in label_class:
+                if c in sub_list:
+                    sub_class_file_list[c].append(item)
+
+    print("Checking image&label pair list for split {} done! 
".format(split)) + return image_label_list, sub_class_file_list + + + +class SemData(Dataset): + def __init__(self, split=3, shot=1, data_root=None, data_list=None, transform=None, mode='train', use_coco=False, use_split_coco=False): + assert mode in ['train', 'val', 'test'] + + self.mode = mode + self.split = split + self.shot = shot + self.data_root = data_root + + if not use_coco: + self.class_list = list(range(1, 21)) #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20] + if self.split == 3: + self.sub_list = list(range(1, 16)) #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15] + self.sub_val_list = list(range(16, 21)) #[16,17,18,19,20] + elif self.split == 2: + self.sub_list = list(range(1, 11)) + list(range(16, 21)) #[1,2,3,4,5,6,7,8,9,10,16,17,18,19,20] + self.sub_val_list = list(range(11, 16)) #[11,12,13,14,15] + elif self.split == 1: + self.sub_list = list(range(1, 6)) + list(range(11, 21)) #[1,2,3,4,5,11,12,13,14,15,16,17,18,19,20] + self.sub_val_list = list(range(6, 11)) #[6,7,8,9,10] + elif self.split == 0: + self.sub_list = list(range(6, 21)) #[6,7,8,9,10,11,12,13,14,15,16,17,18,19,20] + self.sub_val_list = list(range(1, 6)) #[1,2,3,4,5] + + else: + if use_split_coco: + print('INFO: using SPLIT COCO') + self.class_list = list(range(1, 81)) + if self.split == 3: + self.sub_val_list = list(range(4, 81, 4)) + self.sub_list = list(set(self.class_list) - set(self.sub_val_list)) + elif self.split == 2: + self.sub_val_list = list(range(3, 80, 4)) + self.sub_list = list(set(self.class_list) - set(self.sub_val_list)) + elif self.split == 1: + self.sub_val_list = list(range(2, 79, 4)) + self.sub_list = list(set(self.class_list) - set(self.sub_val_list)) + elif self.split == 0: + self.sub_val_list = list(range(1, 78, 4)) + self.sub_list = list(set(self.class_list) - set(self.sub_val_list)) + else: + print('INFO: using COCO') + self.class_list = list(range(1, 81)) + if self.split == 3: + self.sub_list = list(range(1, 61)) + self.sub_val_list = list(range(61, 81)) + elif self.split == 2: + self.sub_list = list(range(1, 41)) + list(range(61, 81)) + self.sub_val_list = list(range(41, 61)) + elif self.split == 1: + self.sub_list = list(range(1, 21)) + list(range(41, 81)) + self.sub_val_list = list(range(21, 41)) + elif self.split == 0: + self.sub_list = list(range(21, 81)) + self.sub_val_list = list(range(1, 21)) + + print('sub_list: ', self.sub_list) + print('sub_val_list: ', self.sub_val_list) + + if self.mode == 'train': + self.data_list, self.sub_class_file_list = make_dataset(split, data_root, data_list, self.sub_list) + assert len(self.sub_class_file_list.keys()) == len(self.sub_list) + elif self.mode == 'val': + self.data_list, self.sub_class_file_list = make_dataset(split, data_root, data_list, self.sub_val_list) + assert len(self.sub_class_file_list.keys()) == len(self.sub_val_list) + self.transform = transform + + + def __len__(self): + return len(self.data_list) + + def __getitem__(self, index): + label_class = [] + image_path, label_path = self.data_list[index] + image = cv2.imread(image_path, cv2.IMREAD_COLOR) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = np.float32(image) + label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE) + + if image.shape[0] != label.shape[0] or image.shape[1] != label.shape[1]: + raise (RuntimeError("Query Image & label shape mismatch: " + image_path + " " + label_path + "\n")) + label_class = np.unique(label).tolist() + if 0 in label_class: + label_class.remove(0) + if 255 in label_class: + label_class.remove(255) + new_label_class = [] + for c in 
label_class: + if c in self.sub_val_list: + if self.mode == 'val' or self.mode == 'test': + new_label_class.append(c) + if c in self.sub_list: + if self.mode == 'train': + new_label_class.append(c) + label_class = new_label_class + assert len(label_class) > 0 + + + class_chosen = label_class[random.randint(1,len(label_class))-1] + class_chosen = class_chosen + target_pix = np.where(label == class_chosen) + ignore_pix = np.where(label == 255) + label[:,:] = 0 + if target_pix[0].shape[0] > 0: + label[target_pix[0],target_pix[1]] = 1 + label[ignore_pix[0],ignore_pix[1]] = 255 + + + file_class_chosen = self.sub_class_file_list[class_chosen] + num_file = len(file_class_chosen) + + support_image_path_list = [] + support_label_path_list = [] + support_idx_list = [] + for k in range(self.shot): + support_idx = random.randint(1,num_file)-1 + support_image_path = image_path + support_label_path = label_path + while((support_image_path == image_path and support_label_path == label_path) or support_idx in support_idx_list): + support_idx = random.randint(1,num_file)-1 + support_image_path, support_label_path = file_class_chosen[support_idx] + support_idx_list.append(support_idx) + support_image_path_list.append(support_image_path) + support_label_path_list.append(support_label_path) + + support_image_list = [] + support_label_list = [] + subcls_list = [] + for k in range(self.shot): + if self.mode == 'train': + subcls_list.append(self.sub_list.index(class_chosen)) + else: + subcls_list.append(self.sub_val_list.index(class_chosen)) + support_image_path = support_image_path_list[k] + support_label_path = support_label_path_list[k] + support_image = cv2.imread(support_image_path, cv2.IMREAD_COLOR) + support_image = cv2.cvtColor(support_image, cv2.COLOR_BGR2RGB) + support_image = np.float32(support_image) + support_label = cv2.imread(support_label_path, cv2.IMREAD_GRAYSCALE) + target_pix = np.where(support_label == class_chosen) + ignore_pix = np.where(support_label == 255) + support_label[:,:] = 0 + support_label[target_pix[0],target_pix[1]] = 1 + support_label[ignore_pix[0],ignore_pix[1]] = 255 + if support_image.shape[0] != support_label.shape[0] or support_image.shape[1] != support_label.shape[1]: + raise (RuntimeError("Support Image & label shape mismatch: " + support_image_path + " " + support_label_path + "\n")) + support_image_list.append(support_image) + support_label_list.append(support_label) + assert len(support_label_list) == self.shot and len(support_image_list) == self.shot + + raw_label = label.copy() + if self.transform is not None: + image, label = self.transform(image, label) + for k in range(self.shot): + support_image_list[k], support_label_list[k] = self.transform(support_image_list[k], support_label_list[k]) + + s_xs = support_image_list + s_ys = support_label_list + s_x = s_xs[0].unsqueeze(0) + for i in range(1, self.shot): + s_x = torch.cat([s_xs[i].unsqueeze(0), s_x], 0) + s_y = s_ys[0].unsqueeze(0) + for i in range(1, self.shot): + s_y = torch.cat([s_ys[i].unsqueeze(0), s_y], 0) + + if self.mode == 'train': + return image, label, s_x, s_y, subcls_list + else: + return image, label, s_x, s_y, subcls_list, raw_label + diff --git a/util/transform.py b/util/transform.py new file mode 100755 index 0000000..d215d7f --- /dev/null +++ b/util/transform.py @@ -0,0 +1,376 @@ +import random +import math +import numpy as np +import numbers +import collections +import cv2 + +import torch + +manual_seed = 123 +torch.manual_seed(manual_seed) +np.random.seed(manual_seed) 
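+# The fixed seed above keeps the random augmentations defined in this file reproducible across runs.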
+torch.cuda.manual_seed_all(manual_seed)
+random.seed(manual_seed)
+
+class Compose(object):
+    # Composes segtransforms: segtransform.Compose([segtransform.RandScale([0.5, 2.0]), segtransform.ToTensor()])
+    def __init__(self, segtransform):
+        self.segtransform = segtransform
+
+    def __call__(self, image, label):
+        for t in self.segtransform:
+            image, label = t(image, label)
+        return image, label
+
+class ToTensor(object):
+    # Converts numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W).
+    def __call__(self, image, label):
+        if not isinstance(image, np.ndarray) or not isinstance(label, np.ndarray):
+            raise (RuntimeError("segtransform.ToTensor() only handles np.ndarray "
+                                "[e.g. data read by cv2.imread()].\n"))
+        if len(image.shape) > 3 or len(image.shape) < 2:
+            raise (RuntimeError("segtransform.ToTensor() only handles np.ndarray with 2 or 3 dims.\n"))
+        if len(image.shape) == 2:
+            image = np.expand_dims(image, axis=2)
+        if not len(label.shape) == 2:
+            raise (RuntimeError("segtransform.ToTensor() only handles a label np.ndarray with 2 dims.\n"))
+
+        image = torch.from_numpy(image.transpose((2, 0, 1)))
+        if not isinstance(image, torch.FloatTensor):
+            image = image.float()
+        label = torch.from_numpy(label)
+        if not isinstance(label, torch.LongTensor):
+            label = label.long()
+        return image, label
+
+
+class Normalize(object):
+    # Normalize tensor with mean and standard deviation along channel: channel = (channel - mean) / std
+    def __init__(self, mean, std=None):
+        if std is None:
+            assert len(mean) > 0
+        else:
+            assert len(mean) == len(std)
+        self.mean = mean
+        self.std = std
+
+    def __call__(self, image, label):
+        if self.std is None:
+            for t, m in zip(image, self.mean):
+                t.sub_(m)
+        else:
+            for t, m, s in zip(image, self.mean, self.std):
+                t.sub_(m).div_(s)
+        return image, label
+
+
+class Resize(object):
+    # Resizes the input; 'size' is an int giving the side length of the square output.
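+    # The longer image side is scaled to `size`, both sides are rounded down to a multiple of 8,
+    # and the result is pasted into a `size` x `size` canvas: zeros for the image, 255 (ignore) for the label.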
+ def __init__(self, size): + self.size = size + + def __call__(self, image, label): + + value_scale = 255 + mean = [0.485, 0.456, 0.406] + mean = [item * value_scale for item in mean] + std = [0.229, 0.224, 0.225] + std = [item * value_scale for item in std] + + def find_new_hw(ori_h, ori_w, test_size): + if ori_h >= ori_w: + ratio = test_size*1.0 / ori_h + new_h = test_size + new_w = int(ori_w * ratio) + elif ori_w > ori_h: + ratio = test_size*1.0 / ori_w + new_h = int(ori_h * ratio) + new_w = test_size + + if new_h % 8 != 0: + new_h = (int(new_h /8))*8 + else: + new_h = new_h + if new_w % 8 != 0: + new_w = (int(new_w /8))*8 + else: + new_w = new_w + return new_h, new_w + + test_size = self.size + new_h, new_w = find_new_hw(image.shape[0], image.shape[1], test_size) + #new_h, new_w = test_size, test_size + image_crop = cv2.resize(image, dsize=(int(new_w), int(new_h)), interpolation=cv2.INTER_LINEAR) + back_crop = np.zeros((test_size, test_size, 3)) + # back_crop[:,:,0] = mean[0] + # back_crop[:,:,1] = mean[1] + # back_crop[:,:,2] = mean[2] + back_crop[:new_h, :new_w, :] = image_crop + image = back_crop + + s_mask = label + new_h, new_w = find_new_hw(s_mask.shape[0], s_mask.shape[1], test_size) + #new_h, new_w = test_size, test_size + s_mask = cv2.resize(s_mask.astype(np.float32), dsize=(int(new_w), int(new_h)),interpolation=cv2.INTER_NEAREST) + back_crop_s_mask = np.ones((test_size, test_size)) * 255 + back_crop_s_mask[:new_h, :new_w] = s_mask + label = back_crop_s_mask + + return image, label + + +class test_Resize(object): + # Resize the input to the given size, 'size' is a 2-element tuple or list in the order of (h, w). + def __init__(self, size): + self.size = size + + def __call__(self, image, label): + + value_scale = 255 + mean = [0.485, 0.456, 0.406] + mean = [item * value_scale for item in mean] + std = [0.229, 0.224, 0.225] + std = [item * value_scale for item in std] + + def find_new_hw(ori_h, ori_w, test_size): + if max(ori_h, ori_w) > test_size: + if ori_h >= ori_w: + ratio = test_size*1.0 / ori_h + new_h = test_size + new_w = int(ori_w * ratio) + elif ori_w > ori_h: + ratio = test_size*1.0 / ori_w + new_h = int(ori_h * ratio) + new_w = test_size + + if new_h % 8 != 0: + new_h = (int(new_h /8))*8 + else: + new_h = new_h + if new_w % 8 != 0: + new_w = (int(new_w /8))*8 + else: + new_w = new_w + return new_h, new_w + else: + return ori_h, ori_w + + test_size = self.size + new_h, new_w = find_new_hw(image.shape[0], image.shape[1], test_size) + if new_w != image.shape[0] or new_h != image.shape[1]: + image_crop = cv2.resize(image, dsize=(int(new_w), int(new_h)), interpolation=cv2.INTER_LINEAR) + else: + image_crop = image.copy() + back_crop = np.zeros((test_size, test_size, 3)) + back_crop[:new_h, :new_w, :] = image_crop + image = back_crop + + s_mask = label + new_h, new_w = find_new_hw(s_mask.shape[0], s_mask.shape[1], test_size) + if new_w != s_mask.shape[0] or new_h != s_mask.shape[1]: + s_mask = cv2.resize(s_mask.astype(np.float32), dsize=(int(new_w), int(new_h)),interpolation=cv2.INTER_NEAREST) + back_crop_s_mask = np.ones((test_size, test_size)) * 255 + back_crop_s_mask[:new_h, :new_w] = s_mask + label = back_crop_s_mask + + return image, label + + +class RandScale(object): + # Randomly resize image & label with scale factor in [scale_min, scale_max] + def __init__(self, scale, aspect_ratio=None): + assert (isinstance(scale, collections.Iterable) and len(scale) == 2) + if isinstance(scale, collections.Iterable) and len(scale) == 2 \ + and isinstance(scale[0], 
numbers.Number) and isinstance(scale[1], numbers.Number) \ + and 0 < scale[0] < scale[1]: + self.scale = scale + else: + raise (RuntimeError("segtransform.RandScale() scale param error.\n")) + if aspect_ratio is None: + self.aspect_ratio = aspect_ratio + elif isinstance(aspect_ratio, collections.Iterable) and len(aspect_ratio) == 2 \ + and isinstance(aspect_ratio[0], numbers.Number) and isinstance(aspect_ratio[1], numbers.Number) \ + and 0 < aspect_ratio[0] < aspect_ratio[1]: + self.aspect_ratio = aspect_ratio + else: + raise (RuntimeError("segtransform.RandScale() aspect_ratio param error.\n")) + + def __call__(self, image, label): + temp_scale = self.scale[0] + (self.scale[1] - self.scale[0]) * random.random() + temp_aspect_ratio = 1.0 + if self.aspect_ratio is not None: + temp_aspect_ratio = self.aspect_ratio[0] + (self.aspect_ratio[1] - self.aspect_ratio[0]) * random.random() + temp_aspect_ratio = math.sqrt(temp_aspect_ratio) + scale_factor_x = temp_scale * temp_aspect_ratio + scale_factor_y = temp_scale / temp_aspect_ratio + image = cv2.resize(image, None, fx=scale_factor_x, fy=scale_factor_y, interpolation=cv2.INTER_LINEAR) + label = cv2.resize(label, None, fx=scale_factor_x, fy=scale_factor_y, interpolation=cv2.INTER_NEAREST) + return image, label + + +class Crop(object): + """Crops the given ndarray image (H*W*C or H*W). + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is made. + """ + def __init__(self, size, crop_type='center', padding=None, ignore_label=255): + self.size = size + if isinstance(size, int): + self.crop_h = size + self.crop_w = size + elif isinstance(size, collections.Iterable) and len(size) == 2 \ + and isinstance(size[0], int) and isinstance(size[1], int) \ + and size[0] > 0 and size[1] > 0: + self.crop_h = size[0] + self.crop_w = size[1] + else: + raise (RuntimeError("crop size error.\n")) + if crop_type == 'center' or crop_type == 'rand': + self.crop_type = crop_type + else: + raise (RuntimeError("crop type error: rand | center\n")) + if padding is None: + self.padding = padding + elif isinstance(padding, list): + if all(isinstance(i, numbers.Number) for i in padding): + self.padding = padding + else: + raise (RuntimeError("padding in Crop() should be a number list\n")) + if len(padding) != 3: + raise (RuntimeError("padding channel is not equal with 3\n")) + else: + raise (RuntimeError("padding in Crop() should be a number list\n")) + if isinstance(ignore_label, int): + self.ignore_label = ignore_label + else: + raise (RuntimeError("ignore_label should be an integer number\n")) + + def __call__(self, image, label): + h, w = label.shape + + + pad_h = max(self.crop_h - h, 0) + pad_w = max(self.crop_w - w, 0) + pad_h_half = int(pad_h / 2) + pad_w_half = int(pad_w / 2) + if pad_h > 0 or pad_w > 0: + if self.padding is None: + raise (RuntimeError("segtransform.Crop() need padding while padding argument is None\n")) + image = cv2.copyMakeBorder(image, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=self.padding) + label = cv2.copyMakeBorder(label, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=self.ignore_label) + h, w = label.shape + raw_label = label + raw_image = image + + if self.crop_type == 'rand': + h_off = random.randint(0, h - self.crop_h) + w_off = random.randint(0, w - self.crop_w) + else: + h_off = int((h - self.crop_h) / 2) + w_off = int((w - self.crop_w) / 2) + image = 
image[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w] + label = label[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w] + raw_pos_num = np.sum(raw_label == 1) + pos_num = np.sum(label == 1) + crop_cnt = 0 + while(pos_num < 0.85*raw_pos_num and crop_cnt<=30): + image = raw_image + label = raw_label + if self.crop_type == 'rand': + h_off = random.randint(0, h - self.crop_h) + w_off = random.randint(0, w - self.crop_w) + else: + h_off = int((h - self.crop_h) / 2) + w_off = int((w - self.crop_w) / 2) + image = image[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w] + label = label[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w] + raw_pos_num = np.sum(raw_label == 1) + pos_num = np.sum(label == 1) + crop_cnt += 1 + if crop_cnt >= 50: + image = cv2.resize(raw_image, (self.size[0], self.size[0]), interpolation=cv2.INTER_LINEAR) + label = cv2.resize(raw_label, (self.size[0], self.size[0]), interpolation=cv2.INTER_NEAREST) + + if image.shape != (self.size[0], self.size[0], 3): + image = cv2.resize(image, (self.size[0], self.size[0]), interpolation=cv2.INTER_LINEAR) + label = cv2.resize(label, (self.size[0], self.size[0]), interpolation=cv2.INTER_NEAREST) + + return image, label + + +class RandRotate(object): + # Randomly rotate image & label with rotate factor in [rotate_min, rotate_max] + def __init__(self, rotate, padding, ignore_label=255, p=0.5): + assert (isinstance(rotate, collections.Iterable) and len(rotate) == 2) + if isinstance(rotate[0], numbers.Number) and isinstance(rotate[1], numbers.Number) and rotate[0] < rotate[1]: + self.rotate = rotate + else: + raise (RuntimeError("segtransform.RandRotate() scale param error.\n")) + assert padding is not None + assert isinstance(padding, list) and len(padding) == 3 + if all(isinstance(i, numbers.Number) for i in padding): + self.padding = padding + else: + raise (RuntimeError("padding in RandRotate() should be a number list\n")) + assert isinstance(ignore_label, int) + self.ignore_label = ignore_label + self.p = p + + def __call__(self, image, label): + if random.random() < self.p: + angle = self.rotate[0] + (self.rotate[1] - self.rotate[0]) * random.random() + h, w = label.shape + matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1) + image = cv2.warpAffine(image, matrix, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=self.padding) + label = cv2.warpAffine(label, matrix, (w, h), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=self.ignore_label) + return image, label + + +class RandomHorizontalFlip(object): + def __init__(self, p=0.5): + self.p = p + + def __call__(self, image, label): + if random.random() < self.p: + image = cv2.flip(image, 1) + label = cv2.flip(label, 1) + return image, label + + +class RandomVerticalFlip(object): + def __init__(self, p=0.5): + self.p = p + + def __call__(self, image, label): + if random.random() < self.p: + image = cv2.flip(image, 0) + label = cv2.flip(label, 0) + return image, label + + +class RandomGaussianBlur(object): + def __init__(self, radius=5): + self.radius = radius + + def __call__(self, image, label): + if random.random() < 0.5: + image = cv2.GaussianBlur(image, (self.radius, self.radius), 0) + return image, label + + +class RGB2BGR(object): + # Converts image from RGB order to BGR order, for model initialized from Caffe + def __call__(self, image, label): + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + return image, label + + +class BGR2RGB(object): + # Converts image from BGR order to RGB order, for model initialized from Pytorch + 
def __call__(self, image, label): + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + return image, label diff --git a/util/util.py b/util/util.py new file mode 100755 index 0000000..7457428 --- /dev/null +++ b/util/util.py @@ -0,0 +1,146 @@ +import os +import numpy as np +from PIL import Image + +import torch +from torch import nn +import torch.nn.init as initer + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def step_learning_rate(optimizer, base_lr, epoch, step_epoch, multiplier=0.1): + """Sets the learning rate to the base LR decayed by 10 every step epochs""" + lr = base_lr * (multiplier ** (epoch // step_epoch)) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + +def poly_learning_rate(optimizer, base_lr, curr_iter, max_iter, power=0.9, index_split=-1, scale_lr=10., warmup=False, warmup_step=500): + """poly learning rate policy""" + if warmup and curr_iter < warmup_step: + lr = base_lr * (0.1 + 0.9 * (curr_iter/warmup_step)) + else: + lr = base_lr * (1 - float(curr_iter) / max_iter) ** power + + if curr_iter % 50 == 0: + print('Base LR: {:.4f}, Curr LR: {:.4f}, Warmup: {}.'.format(base_lr, lr, (warmup and curr_iter < warmup_step))) + + for index, param_group in enumerate(optimizer.param_groups): + if index <= index_split: + param_group['lr'] = lr + else: + param_group['lr'] = lr * scale_lr + + +def intersectionAndUnion(output, target, K, ignore_index=255): + # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. + assert (output.ndim in [1, 2, 3]) + assert output.shape == target.shape + output = output.reshape(output.size).copy() + target = target.reshape(target.size) + output[np.where(target == ignore_index)[0]] = ignore_index + intersection = output[np.where(output == target)[0]] + area_intersection, _ = np.histogram(intersection, bins=np.arange(K+1)) + area_output, _ = np.histogram(output, bins=np.arange(K+1)) + area_target, _ = np.histogram(target, bins=np.arange(K+1)) + area_union = area_output + area_target - area_intersection + return area_intersection, area_union, area_target + + +def intersectionAndUnionGPU(output, target, K, ignore_index=255): + # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. 
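+    # GPU counterpart of intersectionAndUnion above: the tensors are flattened, pixels marked with
+    # ignore_index are excluded, and torch.histc bins the per-class intersection/output/target areas.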
+ assert (output.dim() in [1, 2, 3]) + assert output.shape == target.shape + output = output.view(-1) + target = target.view(-1) + output[target == ignore_index] = ignore_index + intersection = output[output == target] + area_intersection = torch.histc(intersection, bins=K, min=0, max=K-1) + area_output = torch.histc(output, bins=K, min=0, max=K-1) + area_target = torch.histc(target, bins=K, min=0, max=K-1) + area_union = area_output + area_target - area_intersection + return area_intersection, area_union, area_target + +def check_mkdir(dir_name): + if not os.path.exists(dir_name): + os.mkdir(dir_name) + + +def check_makedirs(dir_name): + if not os.path.exists(dir_name): + os.makedirs(dir_name) + + +def init_weights(model, conv='kaiming', batchnorm='normal', linear='kaiming', lstm='kaiming'): + """ + :param model: Pytorch Model which is nn.Module + :param conv: 'kaiming' or 'xavier' + :param batchnorm: 'normal' or 'constant' + :param linear: 'kaiming' or 'xavier' + :param lstm: 'kaiming' or 'xavier' + """ + for m in model.modules(): + if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + if conv == 'kaiming': + initer.kaiming_normal_(m.weight) + elif conv == 'xavier': + initer.xavier_normal_(m.weight) + else: + raise ValueError("init type of conv error.\n") + if m.bias is not None: + initer.constant_(m.bias, 0) + + elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):#, BatchNorm1d, BatchNorm2d, BatchNorm3d)): + if batchnorm == 'normal': + initer.normal_(m.weight, 1.0, 0.02) + elif batchnorm == 'constant': + initer.constant_(m.weight, 1.0) + else: + raise ValueError("init type of batchnorm error.\n") + initer.constant_(m.bias, 0.0) + + elif isinstance(m, nn.Linear): + if linear == 'kaiming': + initer.kaiming_normal_(m.weight) + elif linear == 'xavier': + initer.xavier_normal_(m.weight) + else: + raise ValueError("init type of linear error.\n") + if m.bias is not None: + initer.constant_(m.bias, 0) + + elif isinstance(m, nn.LSTM): + for name, param in m.named_parameters(): + if 'weight' in name: + if lstm == 'kaiming': + initer.kaiming_normal_(param) + elif lstm == 'xavier': + initer.xavier_normal_(param) + else: + raise ValueError("init type of lstm error.\n") + elif 'bias' in name: + initer.constant_(param, 0) + + +def colorize(gray, palette): + # gray: numpy array of the label and 1*3N size list palette + color = Image.fromarray(gray.astype(np.uint8)).convert('P') + color.putpalette(palette) + return color
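For reviewers who want to poke at the metric helpers in isolation, below is a minimal sketch (not part of the commit) of how `AverageMeter` and `intersectionAndUnionGPU` from `util/util.py` combine into the IoU numbers logged by `train`/`validate`. The toy masks, the CPU/float cast, and the `util.util` import path are assumptions made purely for illustration; in the training code the inputs are long CUDA tensors produced by the model and the data loader.

```python
import numpy as np
import torch

from util.util import AverageMeter, intersectionAndUnionGPU  # assumes the repo root is on PYTHONPATH

K, ignore_label = 2, 255          # binary few-shot setting: background vs. foreground
intersection_meter = AverageMeter()
union_meter = AverageMeter()

# A toy 1 x 4 x 4 "prediction" and ground truth; 255 marks an ignored pixel.
pred = torch.tensor([[[0, 1, 1, 0],
                      [0, 1, 1, 0],
                      [0, 0, 1, 0],
                      [0, 0, 0, 0]]])
gt   = torch.tensor([[[0, 1, 1, 0],
                      [0, 1, 1, 0],
                      [0, 1, 1, 0],
                      [0, 0, 0, 255]]])

# Cast to float so torch.histc also works on CPU; in train.py the tensors stay on the GPU.
inter, union, target_area = intersectionAndUnionGPU(pred.float(), gt.float(), K, ignore_label)
intersection_meter.update(inter.cpu().numpy())
union_meter.update(union.cpu().numpy())

iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)  # per-class IoU, as in train()/validate()
print('IoU per class:', iou_class, '| FB-IoU: {:.4f}'.format(np.mean(iou_class)))
```

Accumulating `inter`/`union` over many episodes before taking the ratio (rather than averaging per-episode IoUs) is exactly what the meters are for, which is why the same pattern appears in both the training and evaluation loops of this patch.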