-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathaction_train.py
161 lines (131 loc) · 4.89 KB
/
action_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import torch
import os
import json
import torch.nn as nn
import torch.backends.cudnn as cudnn
import random
import numpy as np
import logging
import time
from torch import optim
from torch.utils.data import Dataset
from models.sgcn.st_gcn import Model
def train_logging(file_path):
    """Route root-logger output to *file_path* and emit one probe line per level.

    The probe messages confirm, by inspecting the log file, which severity
    levels actually reach the configured handler.
    """
    logging.basicConfig(filename=file_path, level=logging.DEBUG,
                        format="%(asctime)s %(filename)s %(levelname)s %(message)s",
                        datefmt="%a %d %b %Y %H:%M:%S")
    # One call per severity, lowest to highest.
    for emit, message in ((logging.debug, 'debug'),
                          (logging.info, 'info'),
                          (logging.warning, 'warning'),
                          (logging.error, 'Error'),
                          (logging.critical, 'critical')):
        emit(message)
# 1.load data
# 1.load data
class MyDataset(Dataset):
    """Skeleton-action dataset: each ``.json`` file under ``root_path`` holds a
    list of person annotations, and one file yields one (keypoints, action)
    sample — 17 COCO keypoints (x, y, score) plus the detection-box label used
    as the action class.
    """

    def __init__(self, root_path, transform=None, target_transform=None, annot_file=None):
        """
        Args:
            root_path: directory containing the per-sample ``.json`` files.
            transform, target_transform: kept for interface compatibility;
                not applied anywhere in this class.
            annot_file: optional extra annotation JSON loaded into
                ``self.result``. BUG FIX: the original referenced an undefined
                ``self.annot_file`` and crashed; it is now an explicit,
                optional parameter.
        """
        self.root_path = root_path
        self.transform = transform
        self.target_transform = target_transform
        self.list = os.listdir(self.root_path)
        # BUG FIX: the original iterated the *builtin* ``list`` type instead of
        # ``self.list`` (TypeError at runtime). Also tightened the loose
        # ``f[-4:] == 'json'`` suffix test to require the '.json' extension.
        self.json_list = [f for f in self.list if f.endswith('.json')]
        self.annot_file = annot_file
        self.result = None
        if annot_file is not None:
            with open(annot_file, encoding='utf-8') as annot:
                self.result = json.load(annot)

    def ground_truth_parser(self, json_path):
        """Parse one annotation file into parallel (bbox_anno, kpts_anno) lists."""
        # BUG FIX: context manager closes the file handle (original leaked it).
        with open(json_path, 'rb') as f:
            infos = json.load(f)
        bbox_anno, kpts_anno = [], []
        for info in infos:
            xmin, ymin, width, height = info['bbox']  # top-left corner and size of the detection box
            box_name = info['box_name']  # detection-box label (doubles as the action class)
            bbox_anno.append({'name': box_name, 'xmin': xmin, 'ymin': ymin, 'width': width, 'height': height})
            # COCO-style keypoint annotation record.
            anno = {'keypoints': info['keypoints'], 'num_keypoints': 17,
                    'category_id': 1, 'id': info['id'], 'bbox': info['bbox'], 'area': info['area'],
                    'iscrowd': int(info['iscrowd'])}
            kpts_anno.append(anno)
        return bbox_anno, kpts_anno

    def __len__(self):
        return len(self.json_list)

    def __getitem__(self, item):
        json_name = self.json_list[item]
        json_path = os.path.join(self.root_path, json_name)
        bbox, kpts = self.ground_truth_parser(json_path)
        # NOTE(review): ``item`` indexes both the file list and the entries
        # *inside* the chosen file — this only works when file ``item`` holds
        # at least ``item + 1`` annotations. Confirm the intended indexing
        # (likely index 0 was meant); preserved as-is to avoid a silent
        # behavior change.
        keypoints = np.array(kpts[item]['keypoints'])
        keypoints = keypoints.reshape(17, 3)
        action = np.array(bbox[item]['name'])
        return keypoints, action
def init_seed(seed=1):
    """Seed every RNG in play and force deterministic cuDNN behavior.

    Also serves as a DataLoader ``worker_init_fn``, in which case the worker
    id is passed in as *seed*.
    """
    for seeder in (torch.cuda.manual_seed_all,
                   torch.manual_seed,
                   np.random.seed,
                   random.seed):
        seeder(seed)
    # Trade throughput for reproducibility: no autotuned kernel selection,
    # deterministic convolution algorithms only.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
if __name__ == '__main__':
    # ---- run configuration ----
    device = 0                               # CUDA device index
    root_path = 'data/dataset/out11.json'    # NOTE(review): MyDataset listdir()s this — looks like it should be a directory; confirm
    log_path = 'four_test_result.log'
    total_epoch = 1000
    in_channels = 3                          # (x, y, score) per keypoint
    num_class = 5
    edge_importance_weighting = True
    save_best = False
    graph_args = {'layout': 'coco', 'strategy': 'spatial'}

    # logging
    train_logging(log_path)

    # 1.load data
    mydata = MyDataset(root_path)
    trainset = torch.utils.data.DataLoader(
        dataset=mydata,
        batch_size=16,
        shuffle=True,
        num_workers=0,
        worker_init_fn=init_seed
    )

    # 2.load model
    model = Model(in_channels, num_class, graph_args, edge_importance_weighting)
    devices = torch.device("cuda")
    model.to(devices)

    # 3.define loss
    CELloss = nn.CrossEntropyLoss().cuda(device)

    # 4.define optimizer + LR schedule (decay x0.1 at epochs 500 and 750)
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=0.0004)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [500, 750], gamma=0.1, last_epoch=-1)

    # 5.start train
    model.train()
    train_record = dict()
    # BUG FIX: was 100. — a run whose loss never dipped below 100 would
    # silently save nothing; -inf-safe sentinel instead.
    least_loss = float('inf')
    best_state = None  # CPU snapshot of the best-so-far weights
    for epo in range(total_epoch):
        start = time.time()
        total_loss = []
        print(f'epoch:{epo}')
        for step, (keypoints, label) in enumerate(trainset):
            # no_grad wraps only the host->device transfer, not the forward pass
            with torch.no_grad():
                keypoints = keypoints.float().cuda(device)
                label = label.cuda(device)
            # forward
            output = model(keypoints)
            loss = CELloss(output, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss.append(loss.item())
        # BUG FIX: the scheduler was created but never stepped, so the
        # learning rate never decayed at epochs 500/750.
        scheduler.step()
        mean_epoch_loss = np.mean(total_loss)
        logging.info(
            f"\tEpoch: {epo + 1}/{total_epoch}\tcost: {time.time() - start:.4f}\tloss: {mean_epoch_loss:.4f}")
        if mean_epoch_loss < least_loss:
            logging.info(f'save epoch:{epo+1}')
            print(f'save epoch:{epo+1}')
            save_best = True
            least_loss = mean_epoch_loss
            # BUG FIX: snapshot the weights at the moment of improvement —
            # the original saved only after the last epoch, so it wrote the
            # *final* epoch's weights, not the best ones.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        once = '\tMean training loss: {:.4f}.'.format(mean_epoch_loss)
        print(once)

    # 7.save weights (best checkpoint seen during training)
    if save_best and best_state is not None:
        os.makedirs('weights', exist_ok=True)
        torch.save(best_state, 'weights/five_layer.pt')