
Commit c3ac4f5

Support AMP and resume training in fastface

With partial-fc, AMP must be applied to the backbone only. To implement `resume training`, each GPU has to save and load its own part of the classifier weight.
1 parent 91ff631 commit c3ac4f5
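
The first point is worth spelling out: autocast wraps only the backbone, and its output is cast back to fp32 before the head, so the sharded margin-softmax classifier stays in full precision. A minimal standalone sketch of the pattern the face_baseline.py diff below implements (names here are illustrative, not fastreid API):

import torch

def extract_features(backbone, images, amp_enabled=True):
    # Only the backbone runs in mixed precision.
    with torch.cuda.amp.autocast(amp_enabled):
        features = backbone(images)
    # Cast back to fp32 so everything downstream (head, partial-fc) stays full precision.
    return features.float() if amp_enabled else features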

15 files changed: +432 −178 lines

projects/FastFace/configs/face_base.yml (+8 −5)

@@ -4,6 +4,9 @@ MODEL:
   PIXEL_MEAN: [127.5, 127.5, 127.5]
   PIXEL_STD: [127.5, 127.5, 127.5]

+  BACKBONE:
+    NAME: build_iresnet_backbone
+
   HEADS:
     NAME: FaceHead
     WITH_BNNECK: True
@@ -30,7 +33,7 @@ MODEL:
 DATASETS:
   REC_PATH: /export/home/DATA/Glint360k/train.rec
   NAMES: ("MS1MV2",)
-  TESTS: ("CPLFW", "VGG2_FP", "CALFW", "CFP_FF", "CFP_FP", "AgeDB_30", "LFW")
+  TESTS: ("CFP_FP", "AgeDB_30", "LFW")

 INPUT:
   SIZE_TRAIN: [0,] # No need of resize
@@ -47,10 +50,10 @@ DATALOADER:
 SOLVER:
   MAX_EPOCH: 20
   AMP:
-    ENABLED: False
+    ENABLED: True

   OPT: SGD
-  BASE_LR: 0.1
+  BASE_LR: 0.05
   MOMENTUM: 0.9

   SCHED: MultiStepLR
@@ -59,10 +62,10 @@ SOLVER:
   BIAS_LR_FACTOR: 1.
   WEIGHT_DECAY: 0.0005
   WEIGHT_DECAY_BIAS: 0.0005
-  IMS_PER_BATCH: 512
+  IMS_PER_BATCH: 256

   WARMUP_FACTOR: 0.1
-  WARMUP_ITERS: 5000
+  WARMUP_ITERS: 0

   CHECKPOINT_PERIOD: 1
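
As a usage sketch, these options reach the trainer through fastreid's yacs config plus the project's add_face_cfg hook. Assuming the FastFace project directory is on PYTHONPATH, loading the merged config looks roughly like:

from fastreid.config import get_cfg
from fastface import add_face_cfg

cfg = get_cfg()
add_face_cfg(cfg)  # registers MODEL.BACKBONE.DROPOUT, MODEL.HEADS.PFC, DATASETS.REC_PATH
cfg.merge_from_file("projects/FastFace/configs/face_base.yml")
assert cfg.SOLVER.AMP.ENABLED and cfg.SOLVER.BASE_LR == 0.05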

projects/FastFace/configs/r50_ir.yml (+2 −3)

@@ -3,13 +3,12 @@ _BASE_: face_base.yml
 MODEL:

   BACKBONE:
-    NAME: build_resnetIR_backbone
     DEPTH: 50x
     FEAT_DIM: 25088 # 512x7x7
-    WITH_SE: True
+    DROPOUT: 0.

   HEADS:
     PFC:
       ENABLED: True

-OUTPUT_DIR: projects/FastFace/logs/ir_se50-glink360k-pfc0.1
+OUTPUT_DIR: projects/FastFace/logs/pfc0.1_insightface

projects/FastFace/fastface/__init__.py (+1)

@@ -7,3 +7,4 @@
 from .modeling import *
 from .config import add_face_cfg
 from .trainer import FaceTrainer
+from .datasets import *

projects/FastFace/fastface/config.py (+2)

@@ -12,5 +12,7 @@ def add_face_cfg(cfg):

     _C.DATASETS.REC_PATH = ""

+    _C.MODEL.BACKBONE.DROPOUT = 0.
+
     _C.MODEL.HEADS.PFC = CN({"ENABLED": False})
     _C.MODEL.HEADS.PFC.SAMPLE_RATE = 0.1

projects/FastFace/fastface/datasets/ms1mv2.py (+1 −1)

@@ -23,7 +23,7 @@ def __init__(self, root="datasets", **kwargs):
         required_files = [self.dataset_dir]
         self.check_before_run(required_files)

-        train = self.process_dirs()
+        train = self.process_dirs()[:10000]
         super().__init__(train, [], [], **kwargs)

     def process_dirs(self):

projects/FastFace/fastface/modeling/__init__.py (+1 −1)

@@ -7,4 +7,4 @@
 from .partial_fc import PartialFC
 from .face_baseline import FaceBaseline
 from .face_head import FaceHead
-from .resnet_ir import build_resnetIR_backbone
+from .iresnet import build_iresnet_backbone

projects/FastFace/fastface/modeling/face_baseline.py (+23 −6)

@@ -4,6 +4,7 @@

 """

+import torch
 from fastreid.modeling.meta_arch import Baseline
 from fastreid.modeling.meta_arch import META_ARCH_REGISTRY

@@ -13,12 +14,28 @@ class FaceBaseline(Baseline):
     def __init__(self, cfg):
         super().__init__(cfg)
         self.pfc_enabled = cfg.MODEL.HEADS.PFC.ENABLED
+        self.amp_enabled = cfg.SOLVER.AMP.ENABLED

-    def losses(self, outputs, gt_labels):
+    def forward(self, batched_inputs):
         if not self.pfc_enabled:
-            return super().losses(outputs, gt_labels)
+            return super().forward(batched_inputs)
+
+        images = self.preprocess_image(batched_inputs)
+        with torch.cuda.amp.autocast(self.amp_enabled):
+            features = self.backbone(images)
+        features = features.float() if self.amp_enabled else features
+
+        if self.training:
+            assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
+            targets = batched_inputs["targets"]
+
+            # PreciseBN flag. When running preciseBN on a different dataset, the number of
+            # classes in the new dataset may be larger than in the original one, so
+            # circle/arcface would throw an error. We set all targets to 0 to avoid this.
+            if targets.sum() < 0: targets.zero_()
+
+            outputs = self.heads(features, targets)
+            return outputs, targets
         else:
-            # model parallel with partial-fc
-            # cls layer and loss computation in partial_fc.py
-            pred_features = outputs["features"]
-            return pred_features, gt_labels
+            outputs = self.heads(features)
+            return outputs
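
With PFC enabled, forward now returns the fp32 embeddings and targets; the trainer hands them to PartialFC, which owns the classifier and the loss, and the backward pass follows the standard torch.cuda.amp recipe. A sketch of that loop (illustrative names, not fastreid's actual trainer code):

import torch

scaler = torch.cuda.amp.GradScaler(enabled=True)

def run_step(model, partial_fc_loss, optimizer, batched_inputs):
    # forward() already ran the backbone under autocast and cast back to fp32
    embeddings, targets = model(batched_inputs)
    loss = partial_fc_loss(embeddings, targets)  # margin softmax on this rank's shard
    optimizer.zero_grad()
    scaler.scale(loss).backward()  # scaled to avoid fp16 gradient underflow
    scaler.step(optimizer)
    scaler.update()
    return loss.detach()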

projects/FastFace/fastface/modeling/face_head.py (+1 −7)

@@ -30,10 +30,4 @@ def forward(self, features, targets=None):
         pool_feat = self.pool_layer(features)
         neck_feat = self.bottleneck(pool_feat)
         neck_feat = neck_feat[..., 0, 0]
-
-        if not self.training:
-            return neck_feat
-
-        return {
-            "features": neck_feat,
-        }
+        return neck_feat

projects/FastFace/fastface/modeling/iresnet.py (+179, new file)

@@ -0,0 +1,179 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+
+"""
+
+import torch
+from torch import nn
+
+from fastreid.layers import get_norm
+from fastreid.modeling.backbones import BACKBONE_REGISTRY
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes,
+                     out_planes,
+                     kernel_size=1,
+                     stride=stride,
+                     bias=False)
+
+
+class IBasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, bn_norm, stride=1, downsample=None,
+                 groups=1, base_width=64, dilation=1):
+        super().__init__()
+        if groups != 1 or base_width != 64:
+            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        self.bn1 = get_norm(bn_norm, inplanes)
+        self.conv1 = conv3x3(inplanes, planes)
+        self.bn2 = get_norm(bn_norm, planes)
+        self.prelu = nn.PReLU(planes)
+        self.conv2 = conv3x3(planes, planes, stride)
+        self.bn3 = get_norm(bn_norm, planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+        out = self.bn1(x)
+        out = self.conv1(out)
+        out = self.bn2(out)
+        out = self.prelu(out)
+        out = self.conv2(out)
+        out = self.bn3(out)
+        if self.downsample is not None:
+            identity = self.downsample(x)
+        out += identity
+        return out
+
+
+class IResNet(nn.Module):
+    fc_scale = 7 * 7
+
+    def __init__(self, block, layers, bn_norm, dropout=0, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
+        super().__init__()
+        self.inplanes = 64
+        self.dilation = 1
+        self.fp16 = fp16
+        if replace_stride_with_dilation is None:
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = get_norm(bn_norm, self.inplanes)
+        self.prelu = nn.PReLU(self.inplanes)
+        self.layer1 = self._make_layer(block, 64, layers[0], bn_norm, stride=2)
+        self.layer2 = self._make_layer(block,
+                                       128,
+                                       layers[1],
+                                       bn_norm,
+                                       stride=2,
+                                       dilate=replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(block,
+                                       256,
+                                       layers[2],
+                                       bn_norm,
+                                       stride=2,
+                                       dilate=replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(block,
+                                       512,
+                                       layers[3],
+                                       bn_norm,
+                                       stride=2,
+                                       dilate=replace_stride_with_dilation[2])
+        self.bn2 = get_norm(bn_norm, 512 * block.expansion)
+        self.dropout = nn.Dropout(p=dropout, inplace=True)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.normal_(m.weight, 0, 0.1)
+            elif m.__class__.__name__.find('Norm') != -1:
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, IBasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, bn_norm, stride=1, dilate=False):
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                get_norm(bn_norm, planes * block.expansion),
+            )
+        layers = []
+        layers.append(
+            block(self.inplanes, planes, bn_norm, stride, downsample, self.groups,
+                  self.base_width, previous_dilation))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(
+                block(self.inplanes,
+                      planes,
+                      bn_norm,
+                      groups=self.groups,
+                      base_width=self.base_width,
+                      dilation=self.dilation))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.prelu(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.bn2(x)
+        x = self.dropout(x)
+        return x
+
+
+@BACKBONE_REGISTRY.register()
+def build_iresnet_backbone(cfg):
+    """
+    Create an IResNet instance from config.
+    Returns:
+        IResNet: a :class:`IResNet` instance.
+    """
+
+    # fmt: off
+    bn_norm = cfg.MODEL.BACKBONE.NORM
+    depth   = cfg.MODEL.BACKBONE.DEPTH
+    dropout = cfg.MODEL.BACKBONE.DROPOUT
+    fp16    = cfg.SOLVER.AMP.ENABLED
+    # fmt: on
+
+    num_blocks_per_stage = {
+        '18x': [2, 2, 2, 2],
+        '34x': [3, 4, 6, 3],
+        '50x': [3, 4, 14, 3],
+        '100x': [3, 13, 30, 3],
+        '200x': [6, 26, 60, 6],
+    }[depth]
+
+    model = IResNet(IBasicBlock, num_blocks_per_stage, bn_norm, dropout, fp16=fp16)
+    return model
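
A quick shape check for the new backbone (a sketch; assumes fastreid and the FastFace project are on PYTHONPATH, and that "BN" is a norm type get_norm accepts):

import torch
from fastface.modeling.iresnet import IResNet, IBasicBlock

# '50x' stage layout from build_iresnet_backbone above
model = IResNet(IBasicBlock, [3, 4, 14, 3], "BN", dropout=0.).eval()
with torch.no_grad():
    feat = model(torch.randn(2, 3, 112, 112))
print(feat.shape)  # torch.Size([2, 512, 7, 7]); flattened = 25088, matching FEAT_DIM

Each of the four stages downsamples by 2 while conv1 keeps stride 1, so a 112x112 face crop ends at 7x7; that is where fc_scale = 7 * 7 and the FEAT_DIM of 25088 in r50_ir.yml come from.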

projects/FastFace/fastface/modeling/partial_fc.py (−17)

@@ -52,23 +52,6 @@ def __init__(
 
         self.cls_layer = getattr(any_softmax, cls_type)(num_classes, scale, margin)
 
-        """ TODO: consider resume training
-        if resume:
-            try:
-                self.weight: torch.Tensor = torch.load(self.weight_name)
-                logging.info("softmax weight resume successfully!")
-            except (FileNotFoundError, KeyError, IndexError):
-                self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
-                logging.info("softmax weight resume fail!")
-
-            try:
-                self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name)
-                logging.info("softmax weight mom resume successfully!")
-            except (FileNotFoundError, KeyError, IndexError):
-                self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
-                logging.info("softmax weight mom resume fail!")
-        else:
-        """
         self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
         self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
         logger.info("softmax weight init successfully!")
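
The removed TODO is what the commit message refers to: each rank owns only a shard of the classifier weight and its momentum buffer, so resuming must save and load per-rank tensors. One possible shape of that logic (a hypothetical sketch; the file names and helpers are illustrative, not the API this commit actually adds):

import os
import torch
import torch.distributed as dist

def save_pfc_shard(weight, weight_mom, save_dir):
    # Every rank writes its own shard; the file name encodes the rank.
    rank = dist.get_rank()
    torch.save(weight, os.path.join(save_dir, f"softmax_weight_rank{rank}.pt"))
    torch.save(weight_mom, os.path.join(save_dir, f"softmax_weight_mom_rank{rank}.pt"))

def load_pfc_shard(save_dir, num_local, embedding_size, device):
    rank = dist.get_rank()
    try:
        weight = torch.load(os.path.join(save_dir, f"softmax_weight_rank{rank}.pt"), map_location=device)
        weight_mom = torch.load(os.path.join(save_dir, f"softmax_weight_mom_rank{rank}.pt"), map_location=device)
    except FileNotFoundError:
        # Fall back to fresh init, mirroring the constructor above.
        weight = torch.normal(0, 0.01, (num_local, embedding_size), device=device)
        weight_mom = torch.zeros_like(weight)
    return weight, weight_mom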
