diff --git a/README.md b/README.md index eac6f3b..21f6bc1 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,10 @@ *** Recent updates: +- 2021.05.19 Added multilingual text detection based on DBNet. +- 2021.05.01 Updated CRNN training: fixed the multi-GPU training problem and switched to lmdb-based training (images must first be converted to lmdb; the script folder contains multiprocess code for this conversion). Added several training optimizations and changed the model structure (for training, use the yaml files with lmdb in their names). Actual training results are shown in the table below. +- 2021.03.26 Updated CRNN training results; code uploaded after cleanup. +- 2021.03.06 Updated the CRNN resnet and mobilev3 backbones and their config files. - 2020.12.22 Updated CRNN+CTCLoss+CenterLoss training. - 2020.09.18 Updated the text detection documentation. - 2020.09.12 Updated DB, pse, pan, sast and crnn training/testing code and pretrained models. @@ -27,6 +31,17 @@ - [ ] Train a generalized OCR model - [ ] Deployment combined with chinese_lite - [ ] Mobile deployment +*** +### CRNN model results (experimental) +Trained on MJSynth (MJ) and SynthText (ST) with batchsize=512, tested on the following datasets: + +| Model | Iterations | CUTE80 | IC03_867 |IC13_1015|IC13_857|IC15_1811|IC15_2077|IIIT5k_3000|SVT|SVTP|mean| +|-|-|-|-|-|-|-|-|-|-|-|-| +| resnet34+lstm+ctc |120000| 82.98| 91.92|90.93|91.59|73.10|67.98|90.16|85.16|78.29|83.56| +| mobilev3_large+lstm+ctc | 210000| 73.61| 92.50|90.34|91.59|74.82|68.89|87.56|83.46|77.20|82.21| +| mobilev3_small+lstm+ctc | 210000| 66.31| 90.77|88.76|91.13|73.66|69.52|88.80|84.54|72.24|80.64| + + *** ### Detection model results (experimental) @@ -93,6 +108,18 @@ +*** + +### DBNet multilingual text detection results + +#### Generated dataset: + + +#### Public datasets: + + + + *** ### For questions and discussion, add me on WeChat diff --git a/bg_img/1.jpg b/bg_img/1.jpg new file mode 100644 index 0000000..cab5298 Binary files /dev/null and b/bg_img/1.jpg differ diff --git a/bg_img/2.jpg b/bg_img/2.jpg new file mode 100644 index 0000000..7d077d5 Binary files /dev/null and b/bg_img/2.jpg differ diff --git a/bg_img/3.jpg b/bg_img/3.jpg new file mode 100644 index 0000000..da162d0 Binary files /dev/null and b/bg_img/3.jpg differ diff --git a/bg_img/4.jpg b/bg_img/4.jpg new file mode 100644 index 0000000..98edb26 Binary files /dev/null and b/bg_img/4.jpg differ diff --git a/bg_img/5.jpg b/bg_img/5.jpg new file mode 100644 index 0000000..a9e3616 Binary files /dev/null and b/bg_img/5.jpg differ diff --git a/bg_img/6.jpg b/bg_img/6.jpg new file mode 100644 index 0000000..3601bd5 Binary files /dev/null and b/bg_img/6.jpg differ diff --git a/bg_img/7.jpg b/bg_img/7.jpg new file mode 100644 index 0000000..fc01ae2 Binary files /dev/null and b/bg_img/7.jpg differ diff --git a/bg_img/8.jpg b/bg_img/8.jpg new file mode 100644 index 0000000..a94db1d Binary files /dev/null and b/bg_img/8.jpg differ diff --git a/bg_img/9.jpg b/bg_img/9.jpg new file mode 100644 index 0000000..103e601 Binary files /dev/null and b/bg_img/9.jpg differ diff --git "a/checkpoint/\346\226\260\345\273\272\346\226\207\346\234\254\346\226\207\346\241\243.txt" "b/checkpoint/\346\226\260\345\273\272\346\226\207\346\234\254\346\226\207\346\241\243.txt" deleted file mode 100644 index e69de29..0000000 diff --git a/config/det_DB_mobilev3.yaml b/config/det_DB_mobilev3.yaml index b286185..bee1d35 100644 --- a/config/det_DB_mobilev3.yaml +++ b/config/det_DB_mobilev3.yaml @@ -9,12 +9,12 @@ base: crop_shape: [640,640] shrink_ratio: 0.4 n_epoch: 1200 - start_val: 400 + start_val: 700 show_step: 20 checkpoints: ./checkpoint save_epoch: 100 - restore: True - restore_file : ./checkpoint/ag_DB_bb_mobilenet_v3_small_he_DB_Head_bs_16_ep_1200_mobile_slim_all/DB_best.pth.tar + restore: False + restore_file : ./checkpoint/DB_best.pth.tar backbone: function: ptocr.model.backbone.det_mobilev3,mobilenet_v3_small @@ -82,6 +82,6 @@ postprocess: min_size: 3 infer: - model_path: './checkpoint/ag_DB_bb_mobilenet_v3_small_he_DB_Head_bs_16_ep_1200/DB_best.pth.tar' + model_path: './checkpoint/DB_best.pth.tar' path: '/src/notebooks/detect_text/icdar2015/ch4_test_images' save_path: './result'
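The 2021.05.01 entry above notes that training images must first be packed into lmdb (the repo ships a multiprocess converter in the script folder). As a rough single-process illustration only, not the repo's actual script (make_lmdb and its arguments are hypothetical names), a converter just writes the image-%09d / label-%09d / num-samples keys that the lmdb loaders in this patch read back:

```python
import lmdb

def make_lmdb(train_list, lmdb_path, map_size=1 << 40):
    # Pack a train_list.txt of "img_path<TAB>label" lines into an lmdb.
    env = lmdb.open(lmdb_path, map_size=map_size)
    cnt = 0
    with env.begin(write=True) as txn, open(train_list, encoding='utf-8') as fid:
        for line in fid:
            img_path, label = line.strip('\n').split('\t')
            with open(img_path, 'rb') as f:
                img_bin = f.read()  # raw encoded bytes; PIL decodes them at load time
            cnt += 1  # keys are 1-based, matching the index += 1 in the loaders
            txn.put(('image-%09d' % cnt).encode('utf-8'), img_bin)
            txn.put(('label-%09d' % cnt).encode('utf-8'), label.encode('utf-8'))
        txn.put('num-samples'.encode('utf-8'), str(cnt).encode('utf-8'))
    env.close()
```

diff --git 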
a/config/det_DB_mobilev3_common.yaml b/config/det_DB_mobilev3_common.yaml deleted file mode 100644 index 2de20ac..0000000 --- a/config/det_DB_mobilev3_common.yaml +++ /dev/null @@ -1,87 +0,0 @@ -base: - gpu_id: '2' - algorithm: DB - pretrained: True - in_channels: [24, 40, 48, 96] - inner_channels: 96 - k: 50 - adaptive: True - crop_shape: [640,640] - shrink_ratio: 0.4 - n_epoch: 400 - start_val: 500 - show_step: 20 - checkpoints: ./checkpoint - save_epoch: 1 - restore: False - restore_file : ./checkpoint/ag_DB_bb_mobilenet_v3_small_he_DB_Head_bs_16_ep_400/DB_35.pth.tar - -backbone: - function: ptocr.model.backbone.det_mobilev3,mobilenet_v3_small - -head: - function: ptocr.model.head.det_DBHead,DB_Head -# function: ptocr.model.head.det_FPEM_FFM_Head,FPEM_FFM_Head -# function: ptocr.model.head.det_FPNHead,FPN_Head - -segout: - function: ptocr.model.segout.det_DB_segout,SegDetector - -architectures: - model_function: ptocr.model.architectures.det_model,DetModel - loss_function: ptocr.model.architectures.det_model,DetLoss - -loss: - function: ptocr.model.loss.db_loss,DBLoss - l1_scale: 10 - bce_scale: 1 - -#optimizer: -# function: ptocr.optimizer,AdamDecay -# base_lr: 0.002 -# beta1: 0.9 -# beta2: 0.999 - -optimizer: - function: ptocr.optimizer,SGDDecay - base_lr: 0.002 - momentum: 0.99 - weight_decay: 0.00005 - -optimizer_decay: - function: ptocr.optimizer,adjust_learning_rate_poly - factor: 0.9 - -#optimizer_decay: -# function: ptocr.optimizer,adjust_learning_rate -# schedule: [1,2] -# gama: 0.1 - -trainload: - function: ptocr.dataloader.DetLoad.DBProcess,DBProcessTrain - train_file: /src/notebooks/chinese_recognize_data/detection/CommonData/train_list.txt - num_workers: 10 - batch_size: 16 - -testload: - function: ptocr.dataloader.DetLoad.DBProcess,DBProcessTest - test_file: /src/notebooks/detect_text/icdar2015/test_list.txt - test_gt_path: /src/notebooks/detect_text/icdar2015/ch4_test_gts/ - test_size: 736 - stride: 32 - num_workers: 5 - batch_size: 4 - -postprocess: - function: ptocr.postprocess.DBpostprocess,DBPostProcess - is_poly: False - thresh: 0.2 - box_thresh: 0.4 - max_candidates: 1000 - unclip_ratio: 2 - min_size: 3 - -infer: - model_path: './checkpoint/ag_DB_bb_mobilenet_v3_small_he_DB_Head_bs_16_ep_1200/DB_best.pth.tar' - path: '/src/notebooks/detect_text/icdar2015/ch4_test_images' - save_path: './result' diff --git a/config/det_DB_mobilev3_pytorch_qua.yaml b/config/det_DB_mobilev3_pytorch_qua.yaml deleted file mode 100644 index 9206e8f..0000000 --- a/config/det_DB_mobilev3_pytorch_qua.yaml +++ /dev/null @@ -1,88 +0,0 @@ -base: - gpu_id: '2' - algorithm: DB - backend: 'qnnpack' # fbgemm - pretrained: False - in_channels: [24, 40, 48, 96] - inner_channels: 96 - k: 50 - adaptive: True - crop_shape: [640,640] - shrink_ratio: 0.4 - n_epoch: 1200 - start_val: 400 - show_step: 20 - checkpoints: ./checkpoint - save_epoch: 100 - restore: False - restore_file : ./checkpoint/ag_DB_bb_mobilenet_v3_small_he_DB_Head_bs_16_ep_1200_mobile_slim_all/DB_best.pth.tar - -backbone: - function: ptocr.model.backbone.det_mobilev3_pytorch_qua,mobilenet_v3_small - -head: - function: ptocr.model.head.det_DBHead_Qua,DB_Head -# function: ptocr.model.head.det_FPEM_FFM_Head,FPEM_FFM_Head -# function: ptocr.model.head.det_FPNHead,FPN_Head - -segout: - function: ptocr.model.segout.det_DB_segout_qua,SegDetector - -architectures: - model_function: ptocr.model.architectures.det_model_q,DetModel - loss_function: ptocr.model.architectures.det_model_q,DetLoss - -loss: - function: ptocr.model.loss.db_loss,DBLoss - 
l1_scale: 10 - bce_scale: 1 - -#optimizer: -# function: ptocr.optimizer,AdamDecay -# base_lr: 0.002 -# beta1: 0.9 -# beta2: 0.999 - -optimizer: - function: ptocr.optimizer,SGDDecay - base_lr: 0.002 - momentum: 0.99 - weight_decay: 0.00005 - -optimizer_decay: - function: ptocr.optimizer,adjust_learning_rate_poly - factor: 0.9 - -#optimizer_decay: -# function: ptocr.optimizer,adjust_learning_rate -# schedule: [1,2] -# gama: 0.1 - -trainload: - function: ptocr.dataloader.DetLoad.DBProcess,DBProcessTrain - train_file: /src/notebooks/detect_text/icdar2015/train_list.txt - num_workers: 10 - batch_size: 16 - -testload: - function: ptocr.dataloader.DetLoad.DBProcess,DBProcessTest - test_file: /src/notebooks/detect_text/icdar2015/test_list.txt - test_gt_path: /src/notebooks/detect_text/icdar2015/ch4_test_gts/ - test_size: 736 - stride: 32 - num_workers: 5 - batch_size: 4 - -postprocess: - function: ptocr.postprocess.DBpostprocess,DBPostProcess - is_poly: False - thresh: 0.5 - box_thresh: 0.6 - max_candidates: 1000 - unclip_ratio: 2 - min_size: 3 - -infer: - model_path: './checkpoint/ag_DB_bb_mobilenet_v3_small_he_DB_Head_bs_16_ep_1200/DB_best.pth.tar' - path: '/src/notebooks/detect_text/icdar2015/ch4_test_images' - save_path: './result' diff --git a/config/det_DB_resnet50_3_3.yaml b/config/det_DB_resnet50_3_3.yaml index 8b117b5..ff52b4a 100644 --- a/config/det_DB_resnet50_3_3.yaml +++ b/config/det_DB_resnet50_3_3.yaml @@ -1,15 +1,15 @@ base: gpu_id: '0' algorithm: DB - pretrained: True + pretrained: False in_channels: [256, 512, 1024, 2048] inner_channels: 256 k: 50 adaptive: True crop_shape: [640,640] shrink_ratio: 0.4 - n_epoch: 1201 - start_val: 400 + n_epoch: 600 + start_val: 6000 show_step: 20 checkpoints: ./checkpoint save_epoch: 100 @@ -17,11 +17,11 @@ base: restore_file : ./DB.pth.tar backbone: - function: ptocr.model.backbone.det_resnet_3*3,resnet50 + function: ptocr.model.backbone.det_resnet_3_3,resnet50 head: - function: ptocr.model.head.det_DBHead,DB_Head -# function: ptocr.model.head.det_FPEM_FFM_Head,FPEM_FFM_Head +# function: ptocr.model.head.det_DBHead,DB_Head + function: ptocr.model.head.det_FPEM_FFM_Head,FPEM_FFM_Head # function: ptocr.model.head.det_FPNHead,FPN_Head segout: @@ -59,7 +59,7 @@ optimizer_decay: trainload: function: ptocr.dataloader.DetLoad.DBProcess,DBProcessTrain - train_file: /src/notebooks/detect_text/icdar2015/train_list.txt + train_file: /src/notebooks/MyworkData/huayandang/train_list.txt num_workers: 10 batch_size: 8 @@ -75,10 +75,10 @@ testload: postprocess: function: ptocr.postprocess.DBpostprocess,DBPostProcess is_poly: False - thresh: 0.5 - box_thresh: 0.6 + thresh: 0.2 + box_thresh: 0.3 max_candidates: 1000 - unclip_ratio: 2 + unclip_ratio: 1.5 min_size: 3 infer: diff --git a/config/det_DB_resnet50_mul.yaml b/config/det_DB_resnet50_mul.yaml new file mode 100644 index 0000000..e37dcc4 --- /dev/null +++ b/config/det_DB_resnet50_mul.yaml @@ -0,0 +1,89 @@ +base: + gpu_id: '1' # gpu id for training; for multi-gpu training set e.g. '0,1,2' + algorithm: DB # algorithm name + pretrained: True # whether to load pretrained weights + in_channels: [256, 512, 1024, 2048] # + inner_channels: 256 # + k: 50 + n_class: 3 + adaptive: True + crop_shape: [640,640] # crop size of training images + shrink_ratio: 0.4 # inward shrink ratio of the kernel + n_epoch: 1200 # number of training epochs + start_val: 400 # epoch at which validation starts; to skip validation set a value larger than n_epoch + show_step: 20 # print the loss every show_step iterations + checkpoints: ./checkpoint # directory for saving models + save_epoch: 100 # save a model every save_epoch epochs + restore: False # whether to resume training + restore_file : ./DB.pth.tar # model file to load when resuming training + +backbone: + function: ptocr.model.backbone.det_resnet,resnet50 + +head: + 
function: ptocr.model.head.det_DBHead,DB_Head +# function: ptocr.model.head.det_FPEM_FFM_Head,FPEM_FFM_Head +# function: ptocr.model.head.det_FPNHead,FPN_Head + +segout: + function: ptocr.model.segout.det_DB_segout,SegDetectorMul + +architectures: + model_function: ptocr.model.architectures.det_model,DetModel + loss_function: ptocr.model.architectures.det_model,DetLoss + +loss: + function: ptocr.model.loss.db_loss,DBLossMul + l1_scale: 10 + bce_scale: 1 + class_scale: 1 + +#optimizer: +# function: ptocr.optimizer,AdamDecay +# base_lr: 0.002 +# beta1: 0.9 +# beta2: 0.999 + +optimizer: + function: ptocr.optimizer,SGDDecay + base_lr: 0.002 + momentum: 0.99 + weight_decay: 0.0005 + +optimizer_decay: + function: ptocr.optimizer,adjust_learning_rate_poly + factor: 0.9 + +#optimizer_decay: +# function: ptocr.optimizer,adjust_learning_rate +# schedule: [1,2] +# gama: 0.1 + +trainload: + function: ptocr.dataloader.DetLoad.DBProcess,DBProcessTrainMul + train_file: /src/notebooks/fangxuwei_96/TextGenerator-master/output/train/train_list.txt + num_workers: 10 + batch_size: 8 + +testload: + function: ptocr.dataloader.DetLoad.DBProcess,DBProcessTest + test_file: /src/notebooks/detect_text/icdar2015/test_list.txt + test_gt_path: /src/notebooks/detect_text/icdar2015/ch4_test_gts/ + test_size: 736 + stride: 32 + num_workers: 5 + batch_size: 4 + +postprocess: + function: ptocr.postprocess.DBpostprocess,DBPostProcessMul + is_poly: False # at test time set True to detect curved text; otherwise rectangular boxes are output + thresh: 0.5 + box_thresh: 0.6 + max_candidates: 1000 + unclip_ratio: 2 + min_size: 3 + +infer: + model_path: './checkpoint/ag_DB_bb_resnet50_he_DB_Head_bs_8_ep_601_train_mul/DB_best.pth.tar' + path: '/src/notebooks/fangxuwei_96/TextGenerator-master/output/img/' + save_path: './result' diff --git a/config/det_SAST_resnet50_3_3_ori_dataload.yaml b/config/det_SAST_resnet50_3_3_ori_dataload.yaml index 89c4660..a784869 100644 --- a/config/det_SAST_resnet50_3_3_ori_dataload.yaml +++ b/config/det_SAST_resnet50_3_3_ori_dataload.yaml @@ -5,7 +5,7 @@ base: with_attention: True crop_shape: [512,512] n_epoch: 901 - start_val: 500 + start_val: 5000 show_step: 20 checkpoints: ./checkpoint save_epoch: 100 @@ -13,7 +13,7 @@ base: restore_file : ./checkpoint/ag_SAST_bb_resnet50_he_SASTHead_bs_12_ep_2000/SAST_best.pth.tar backbone: - function: ptocr.model.backbone.det_resnet_sast_3*3,resnet50 + function: ptocr.model.backbone.det_resnet_sast_3_3,resnet50 head: function: ptocr.model.head.det_SASTHead,SASTHead @@ -67,7 +67,7 @@ optimizer_decay: trainload: function: ptocr.dataloader.DetLoad.SASTProcess_ori,SASTProcessTrain - train_file: /src/notebooks/detect_text/icdar2015/train_list.txt + train_file: /src/notebooks/MyworkData/huayandang/train_list.txt num_workers: 12 batch_size: 8 min_crop_side_ratio: 0.3 @@ -95,6 +95,6 @@ postprocess: tcl_map_thresh: 0.7 infer: - model_path: './checkpoint/ag_SAST_bb_resnet50_he_SASTHead_bs_8_ep_1000/SAST_best.pth.tar' - path: '/src/notebooks/detect_text/icdar2015/ch4_test_images' + model_path: './checkpoint/ag_SAST_bb_resnet50_he_SASTHead_bs_8_ep_901/SAST_400.pth.tar' + path: '/src/notebooks/MyworkData/huayandang/train' save_path: './result' diff --git a/config/rec_CRNN_mobilev3.yaml b/config/rec_CRNN_mobilev3.yaml deleted file mode 100644 index 3ce7a17..0000000 --- a/config/rec_CRNN_mobilev3.yaml +++ /dev/null @@ -1,78 +0,0 @@ -base: - gpu_id: '0' - algorithm: CRNN - pretrained: True - inchannel: 96 - hiddenchannel: 48 - img_shape: [32,280] - is_gray: True - use_conv: False - use_attention: False - use_lstm: False - lstm_num: 2 - 
classes: 1000 - n_epoch: 20 - start_val: 1 - show_step: 20 - checkpoints: ./checkpoint - save_epoch: 1 - show_num: 10 - restore: False - restore_file : ./DB.pth.tar - -backbone: - function: ptocr.model.backbone.reg_mobilev3,mobilenet_v3_small - -head: - function: ptocr.model.head.rec_CRNNHead,CRNN_Head - -architectures: - model_function: ptocr.model.architectures.rec_model,RecModel - loss_function: ptocr.model.architectures.rec_model,RecLoss - -loss: - function: ptocr.model.loss.ctc_loss,CTCLoss - reduction: 'mean' - -optimizer: - function: ptocr.optimizer,AdamDecay - base_lr: 0.001 - beta1: 0.9 - beta2: 0.999 - weight_decay: 0.00005 - -# optimizer: -# function: ptocr.optimizer,SGDDecay -# base_lr: 0.002 -# momentum: 0.99 -# weight_decay: 0.00005 - -# optimizer_decay: -# function: ptocr.optimizer,adjust_learning_rate_poly -# factor: 0.9 - -optimizer_decay: - function: ptocr.optimizer,adjust_learning_rate - schedule: [5,8,11,14] - gama: 0.1 - -trainload: - function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessTrain - train_file: /src/notebooks/fangxuwei_96/crnn-master/train_data/data/train_file/train.txt - key_file: /src/notebooks/fangxuwei_96/crnn-master/train_data/data/train_file/key.txt - num_workers: 10 - batch_size: 256 - -testload: - function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessTest - test_file: /src/notebooks/fangxuwei_96/crnn-master/train_data/data/train_file/val.txt - num_workers: 5 - batch_size: 256 - -label_transform: - function: ptocr.utils.transform_label,strLabelConverter - -infer: - model_path: '' - path: '' - save_path: '' diff --git a/config/rec_CRNN_mobilev3_large_english_all.yaml b/config/rec_CRNN_mobilev3_large_english_all.yaml new file mode 100644 index 0000000..7cab74a --- /dev/null +++ b/config/rec_CRNN_mobilev3_large_english_all.yaml @@ -0,0 +1,102 @@ +base: + gpu_id: '0,1' + algorithm: CRNN + pretrained: False + inchannel: 960 + hiddenchannel: 96 + img_shape: [32,100] + is_gray: True + use_conv: False + use_attention: False + use_lstm: True + lstm_num: 2 + classes: 1000 + max_iters: 300000 + eval_iter: 10000 + show_step: 100 + checkpoints: ./checkpoint + save_epoch: 1 + show_num: 10 + restore: False + finetune: False + restore_file : ./checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_20_20210207English/CRNN_best.pth.tar + +backbone: + function: ptocr.model.backbone.rec_mobilev3_bd,mobilenet_v3_large + +head: + function: ptocr.model.head.rec_CRNNHead,CRNN_Head + +architectures: + model_function: ptocr.model.architectures.rec_model,RecModel + loss_function: ptocr.model.architectures.rec_model,RecLoss + +loss: + function: ptocr.model.loss.ctc_loss,CTCLoss + use_ctc_weight: False + reduction: 'mean' + center_function: ptocr.model.loss.centerloss,CenterLoss + use_center: False + center_lr: 0.5 + label_score: 0.95 +# min_score: 0.01 + weight_center: 0.000001 + + +optimizer: + function: ptocr.optimizer,AdamDecay + base_lr: 0.001 + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.00005 + +# optimizer: +# function: ptocr.optimizer,SGDDecay +# base_lr: 0.002 +# momentum: 0.99 +# weight_decay: 0.00005 + +# optimizer_decay: +# function: ptocr.optimizer,adjust_learning_rate_poly +# factor: 0.9 + +optimizer_decay: + function: ptocr.optimizer,adjust_learning_rate + schedule: [100000,200000] + gama: 0.1 + +optimizer_decay_center: + function: ptocr.optimizer,adjust_learning_rate_center + schedule: [100000,200000] + gama: 0.1 + +trainload: + function: ptocr.dataloader.RecLoad.CRNNProcess1,GetDataLoad + train_file: 
['/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/','/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/MJSynth'] + batch_ratio: [0.5,0.5] + key_file: /src/notebooks/MyworkData/EnglishCrnnData/key_new.txt + bg_path: ./bg_img/ + num_workers: 16 + batch_size: 512 + +testload: + function: ptocr.dataloader.RecLoad.CRNNProcess1,CRNNProcessTest + test_file: /src/notebooks/MyworkData/EnglishCrnnData/val_new.txt + num_workers: 8 + batch_size: 256 + + +label_transform: + function: ptocr.utils.transform_label,strLabelConverter + +transform: + function: ptocr.dataloader.RecLoad.DataAgument,transform_label + t_type: lower + char_type: En + +infer: +# model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar' + model_path: './checkpoint/ag_CRNN_bb_mobilenet_v3_large_he_CRNN_Head_bs_512_ep_300000_mobilev2_alldata/CRNN_210000.pth.tar' +# path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg' + path: './english_val_img/' + save_path: '' diff --git a/config/rec_CRNN_mobilev3_large_english_lmdb.yaml b/config/rec_CRNN_mobilev3_large_english_lmdb.yaml new file mode 100644 index 0000000..9479d4a --- /dev/null +++ b/config/rec_CRNN_mobilev3_large_english_lmdb.yaml @@ -0,0 +1,77 @@ +base: + gpu_id: '0' + algorithm: CRNN + pretrained: False + inchannel: 960 + hiddenchannel: 96 + img_shape: [32,100] + is_gray: True + use_attention: False + use_lstm: True + lstm_num: 2 + n_epoch: 8 + start_val: 0 + show_step: 50 + checkpoints: ./checkpoint + save_epoch: 1 + show_num: 10 + restore: False + finetune: False + restore_file : ./checkpoint/ + +backbone: + function: ptocr.model.backbone.rec_mobilev3_bd,mobilenet_v3_large + +head: + function: ptocr.model.head.rec_CRNNHead,CRNN_Head + +architectures: + model_function: ptocr.model.architectures.rec_model,RecModel + loss_function: ptocr.model.architectures.rec_model,RecLoss + +loss: + function: ptocr.model.loss.ctc_loss,CTCLoss + ctc_type: 'warpctc' # torchctc + use_ctc_weight: False + loss_title: ['ctc_loss'] + +optimizer: + function: ptocr.optimizer,AdamDecay + base_lr: 0.001 + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.00005 + + +optimizer_decay: + function: ptocr.optimizer,adjust_learning_rate + schedule: [4,6] + gama: 0.1 + + +trainload: + function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessLmdbLoad + train_file: '/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/' + key_file: /src/notebooks/MyworkData/EnglishCrnnData/key_new.txt + bg_path: ./bg_img/ + num_workers: 10 + batch_size: 512 + +valload: + function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessLmdbLoad + val_file: '/src/notebooks/IIIT5k_3000/lmdb/' + num_workers: 5 + batch_size: 256 + +label_transform: + function: ptocr.utils.transform_label,strLabelConverter + label_function: ptocr.dataloader.RecLoad.DataAgument,transform_label + t_type: lower + char_type: En + +infer: +# model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar' + model_path: './checkpoint/ag_CRNN_bb_resnet34_he_CRNN_Head_bs_512_ep_8_center_loss/CRNN_best.pth.tar' +# path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg' + path: './english_val_img/SVT/image/' + save_path: '' diff --git a/config/rec_CRNN_mobilev3_small_english_all.yaml b/config/rec_CRNN_mobilev3_small_english_all.yaml new file mode 100644 index 0000000..4ec984f --- /dev/null +++ b/config/rec_CRNN_mobilev3_small_english_all.yaml @@ -0,0 +1,104 @@ +base: + gpu_id: '1' + 
algorithm: CRNN + pretrained: False + inchannel: 576 + hiddenchannel: 48 + img_shape: [32,100] + is_gray: True + use_conv: False + use_attention: False + use_lstm: True + lstm_num: 2 + classes: 1000 + max_iters: 300000 + eval_iter: 10000 + show_step: 100 + checkpoints: ./checkpoint + save_epoch: 1 + show_num: 10 + restore: False + finetune: False + restore_file : ./checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_20_20210207English/CRNN_best.pth.tar + +backbone: + function: ptocr.model.backbone.rec_mobilev3_bd,mobilenet_v3_small + +head: + function: ptocr.model.head.rec_CRNNHead,CRNN_Head + +architectures: + model_function: ptocr.model.architectures.rec_model,RecModel + loss_function: ptocr.model.architectures.rec_model,RecLoss + +loss: + function: ptocr.model.loss.ctc_loss,CTCLoss + use_ctc_weight: False + reduction: 'mean' + center_function: ptocr.model.loss.centerloss,CenterLoss + use_center: False + center_lr: 0.5 + label_score: 0.95 +# min_score: 0.01 + weight_center: 0.000001 + + +optimizer: + function: ptocr.optimizer,AdamDecay + base_lr: 0.001 + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.00005 + +# optimizer: +# function: ptocr.optimizer,SGDDecay +# base_lr: 0.002 +# momentum: 0.99 +# weight_decay: 0.00005 + +# optimizer_decay: +# function: ptocr.optimizer,adjust_learning_rate_poly +# factor: 0.9 + +optimizer_decay: + function: ptocr.optimizer,adjust_learning_rate + schedule: [100000,200000] + gama: 0.1 + +optimizer_decay_center: + function: ptocr.optimizer,adjust_learning_rate_center + schedule: [100000,200000] + gama: 0.1 + +trainload: + function: ptocr.dataloader.RecLoad.CRNNProcess1,GetDataLoad + train_file: ['/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/','/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/MJSynth'] + batch_ratio: [0.5,0.5] + key_file: /src/notebooks/MyworkData/EnglishCrnnData/key_new.txt + bg_path: ./bg_img/ + num_workers: 16 + batch_size: 512 + +valload: + function: ptocr.dataloader.RecLoad.CRNNProcess1,GetValDataLoad + root: '/src/notebooks/pytorchOCR-master/english_val_img' + dir: ['CUTE80','IC03_867','IC13_1015','IC13_857','IC15_1811','IIIT5k_3000','SVT','SVTP','IC15_2077'] + test_file: /src/notebooks/MyworkData/EnglishCrnnData/val_new.txt + num_workers: 2 + batch_size: 1 + + +label_transform: + function: ptocr.utils.transform_label,strLabelConverter + +transform: + function: ptocr.dataloader.RecLoad.DataAgument,transform_label + t_type: lower + char_type: En + +infer: +# model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar' + model_path: './checkpoint/ag_CRNN_bb_mobilenet_v3_small_he_CRNN_Head_bs_512_ep_300000_mobilev2_small_alldata/CRNN_210000.pth.tar' +# path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg' + path: './english_val_img/' + save_path: '' diff --git a/config/rec_CRNN_mobilev3_small_english_lmdb.yaml b/config/rec_CRNN_mobilev3_small_english_lmdb.yaml new file mode 100644 index 0000000..7cf838b --- /dev/null +++ b/config/rec_CRNN_mobilev3_small_english_lmdb.yaml @@ -0,0 +1,77 @@ +base: + gpu_id: '0' + algorithm: CRNN + pretrained: False + inchannel: 576 + hiddenchannel: 48 + img_shape: [32,100] + is_gray: True + use_attention: False + use_lstm: True + lstm_num: 2 + n_epoch: 8 + start_val: 0 + show_step: 50 + checkpoints: ./checkpoint + save_epoch: 1 + show_num: 10 + restore: False + finetune: False + restore_file : ./checkpoint/ + +backbone: + function: ptocr.model.backbone.rec_mobilev3_bd,mobilenet_v3_small + +head: 
+ function: ptocr.model.head.rec_CRNNHead,CRNN_Head + +architectures: + model_function: ptocr.model.architectures.rec_model,RecModel + loss_function: ptocr.model.architectures.rec_model,RecLoss + +loss: + function: ptocr.model.loss.ctc_loss,CTCLoss + ctc_type: 'warpctc' # torchctc + use_ctc_weight: False + loss_title: ['ctc_loss'] + +optimizer: + function: ptocr.optimizer,AdamDecay + base_lr: 0.001 + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.00005 + + +optimizer_decay: + function: ptocr.optimizer,adjust_learning_rate + schedule: [4,6] + gama: 0.1 + + +trainload: + function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessLmdbLoad + train_file: '/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/' + key_file: /src/notebooks/MyworkData/EnglishCrnnData/key_new.txt + bg_path: ./bg_img/ + num_workers: 10 + batch_size: 512 + +valload: + function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessLmdbLoad + val_file: '/src/notebooks/IIIT5k_3000/lmdb/' + num_workers: 5 + batch_size: 256 + +label_transform: + function: ptocr.utils.transform_label,strLabelConverter + label_function: ptocr.dataloader.RecLoad.DataAgument,transform_label + t_type: lower + char_type: En + +infer: +# model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar' + model_path: './checkpoint/ag_CRNN_bb_resnet34_he_CRNN_Head_bs_512_ep_8_center_loss/CRNN_best.pth.tar' +# path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg' + path: './english_val_img/SVT/image/' + save_path: '' diff --git a/config/rec_CRNN_resnet34_english_lmdb.yaml b/config/rec_CRNN_resnet34_english_lmdb.yaml new file mode 100644 index 0000000..f5da317 --- /dev/null +++ b/config/rec_CRNN_resnet34_english_lmdb.yaml @@ -0,0 +1,77 @@ +base: + gpu_id: '0' + algorithm: CRNN + pretrained: False + inchannel: 512 + hiddenchannel: 128 + img_shape: [32,100] + is_gray: True + use_attention: False + use_lstm: True + lstm_num: 2 + n_epoch: 8 + start_val: 0 + show_step: 50 + checkpoints: ./checkpoint + save_epoch: 1 + show_num: 10 + restore: True + finetune: False + restore_file : ./checkpoint/ag_CRNN_bb_resnet34_he_CRNN_Head_bs_512_ep_8_no_attention_no_weight/CRNN_best.pth.tar + +backbone: + function: ptocr.model.backbone.reg_resnet_bd,resnet34 + +head: + function: ptocr.model.head.rec_CRNNHead,CRNN_Head + +architectures: + model_function: ptocr.model.architectures.rec_model,RecModel + loss_function: ptocr.model.architectures.rec_model,RecLoss + +loss: + function: ptocr.model.loss.ctc_loss,CTCLoss + ctc_type: 'warpctc' # torchctc + use_ctc_weight: False + loss_title: ['ctc_loss'] + +optimizer: + function: ptocr.optimizer,AdamDecay + base_lr: 0.001 + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.00005 + + +optimizer_decay: + function: ptocr.optimizer,adjust_learning_rate + schedule: [4,6] + gama: 0.1 + + +trainload: + function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessLmdbLoad + train_file: '/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/' + key_file: /src/notebooks/MyworkData/EnglishCrnnData/key_new.txt + bg_path: ./bg_img/ + num_workers: 16 + batch_size: 512 + +valload: + function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessLmdbLoad + val_file: '/src/notebooks/IIIT5k_3000/lmdb/' + num_workers: 5 + batch_size: 256 + +label_transform: + function: ptocr.utils.transform_label,strLabelConverter + label_function: ptocr.dataloader.RecLoad.DataAgument,transform_label + t_type: lower + char_type: En + +infer: +# model_path: 
'./checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar' + model_path: './checkpoint/ag_CRNN_bb_resnet34_he_CRNN_Head_bs_512_ep_8_center_loss/CRNN_best.pth.tar' +# path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg' + path: './english_val_img/SVT/image/' + save_path: '' diff --git a/config/rec_CRNN_ori.yaml b/config/rec_CRNN_resnet_english.yaml similarity index 59% rename from config/rec_CRNN_ori.yaml rename to config/rec_CRNN_resnet_english.yaml index c2264fe..fc51a28 100644 --- a/config/rec_CRNN_ori.yaml +++ b/config/rec_CRNN_resnet_english.yaml @@ -1,28 +1,28 @@ base: - gpu_id: '1' + gpu_id: '0,1' algorithm: CRNN pretrained: False inchannel: 512 hiddenchannel: 128 - img_shape: [32,200] + img_shape: [32,100] is_gray: True use_conv: False use_attention: False - use_lstm: False - lstm_num: 1 + use_lstm: True + lstm_num: 2 classes: 1000 - n_epoch: 20 + n_epoch: 8 start_val: 0 - show_step: 20 + show_step: 100 checkpoints: ./checkpoint save_epoch: 1 show_num: 10 - restore: False - finetune: False - restore_file : ./checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_20_test_center_3/CRNN_best_ori.pth.tar + restore: True + finetune: True + restore_file : ./checkpoint/ag_CRNN_bb_resnet34_he_CRNN_Head_bs_256_ep_20_no_channel_timestep_rnn/CRNN_best.pth.tar backbone: - function: ptocr.model.backbone.rec_crnn_backbone,rec_crnn_backbone + function: ptocr.model.backbone.reg_resnet_bd,resnet34 head: function: ptocr.model.head.rec_CRNNHead,CRNN_Head @@ -33,13 +33,14 @@ architectures: loss: function: ptocr.model.loss.ctc_loss,CTCLoss - reduction: 'sum' + use_ctc_weight: True + reduction: 'none' center_function: ptocr.model.loss.centerloss,CenterLoss - use_center: False + use_center: True center_lr: 0.5 label_score: 0.95 # min_score: 0.01 - weight_center: 0.000001 + weight_center: 0.001 optimizer: @@ -61,24 +62,25 @@ optimizer: optimizer_decay: function: ptocr.optimizer,adjust_learning_rate - schedule: [5,10,15] + schedule: [4,6] gama: 0.1 optimizer_decay_center: function: ptocr.optimizer,adjust_learning_rate_center - schedule: [6,10,15] + schedule: [4,6] gama: 0.1 trainload: - function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessTrain - train_file: /src/notebooks/fangxuwei_96/crnn-master/train_data/data/train_center/train.txt - key_file: /src/notebooks/fangxuwei_96/crnn-master/train_data/data/train_center/key.txt + function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessTrainLmdb + train_file: '/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/' + key_file: /src/notebooks/MyworkData/EnglishCrnnData/key_new.txt + bg_path: ./bg_img/ num_workers: 10 - batch_size: 256 + batch_size: 512 testload: function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessTest - test_file: /src/notebooks/fangxuwei_96/crnn-master/train_data/data/train_center/val.txt + test_file: /src/notebooks/MyworkData/EnglishCrnnData/val_new.txt num_workers: 5 batch_size: 256 @@ -87,7 +89,8 @@ label_transform: function: ptocr.utils.transform_label,strLabelConverter infer: - model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_20_test_center_3/CRNN_best.pth.tar' -# model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_20_no_center/CRNN_best.pth.tar' - path: '/src/notebooks/fangxuwei_96/crnn-master/train_data/data/gen_data_sim/center_test/default' +# model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar' + model_path: 
'./checkpoint/ag_CRNN_bb_resnet34_he_CRNN_Head_bs_512_ep_8_center_loss/CRNN_best.pth.tar' +# path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg' + path: './english_val_img/SVT/image/' save_path: '' diff --git a/config/rec_CRNN_resnet_english_all.yaml b/config/rec_CRNN_resnet_english_all.yaml new file mode 100644 index 0000000..b0e3771 --- /dev/null +++ b/config/rec_CRNN_resnet_english_all.yaml @@ -0,0 +1,102 @@ +base: + gpu_id: '0,1' + algorithm: CRNN + pretrained: False + inchannel: 512 + hiddenchannel: 256 + img_shape: [32,100] + is_gray: True + use_conv: False + use_attention: False + use_lstm: True + lstm_num: 2 + classes: 1000 + max_iters: 200000 + eval_iter: 10000 + show_step: 100 + checkpoints: ./checkpoint + save_epoch: 1 + show_num: 10 + restore: False + finetune: False + restore_file : ./checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_20_20210207English/CRNN_best.pth.tar + +backbone: + function: ptocr.model.backbone.reg_resnet_bd,resnet34 + +head: + function: ptocr.model.head.rec_CRNNHead,CRNN_Head + +architectures: + model_function: ptocr.model.architectures.rec_model,RecModel + loss_function: ptocr.model.architectures.rec_model,RecLoss + +loss: + function: ptocr.model.loss.ctc_loss,CTCLoss + use_ctc_weight: False + reduction: 'mean' + center_function: ptocr.model.loss.centerloss,CenterLoss + use_center: False + center_lr: 0.5 + label_score: 0.95 +# min_score: 0.01 + weight_center: 0.000001 + + +optimizer: + function: ptocr.optimizer,AdamDecay + base_lr: 0.001 + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.00005 + +# optimizer: +# function: ptocr.optimizer,SGDDecay +# base_lr: 0.002 +# momentum: 0.99 +# weight_decay: 0.00005 + +# optimizer_decay: +# function: ptocr.optimizer,adjust_learning_rate_poly +# factor: 0.9 + +optimizer_decay: + function: ptocr.optimizer,adjust_learning_rate + schedule: [80000,160000] + gama: 0.1 + +optimizer_decay_center: + function: ptocr.optimizer,adjust_learning_rate_center + schedule: [80000,160000] + gama: 0.1 + +trainload: + function: ptocr.dataloader.RecLoad.CRNNProcess1,GetDataLoad + train_file: ['/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/','/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/MJSynth'] + batch_ratio: [0.5,0.5] + key_file: /src/notebooks/MyworkData/EnglishCrnnData/key_new.txt + bg_path: ./bg_img/ + num_workers: 16 + batch_size: 512 + +testload: + function: ptocr.dataloader.RecLoad.CRNNProcess1,CRNNProcessTest + test_file: /src/notebooks/MyworkData/EnglishCrnnData/val_new.txt + num_workers: 8 + batch_size: 256 + + +label_transform: + function: ptocr.utils.transform_label,strLabelConverter + +transform: + function: ptocr.dataloader.RecLoad.DataAgument,transform_label + t_type: lower + char_type: En + +infer: +# model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar' + model_path: './checkpoint/ag_CRNN_bb_resnet34_he_CRNN_Head_bs_512_ep_200000_alldata/CRNN_120000.pth.tar' +# path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg' + path: './english_val_img/' + save_path: '' diff --git a/config/rec_CRNN_vgg16_bn.yaml b/config/rec_CRNN_vgg16_bn.yaml deleted file mode 100644 index e2f4747..0000000 --- a/config/rec_CRNN_vgg16_bn.yaml +++ /dev/null @@ -1,78 +0,0 @@ -base: - gpu_id: '1' - algorithm: CRNN - pretrained: True - inchannel: 512 - hiddenchannel: 128 - img_shape: [32,100] - is_gray: True - use_conv: False - use_attention: False - use_lstm: False - lstm_num: 2 - classes: 
1000 - n_epoch: 100 - start_val: 10 - show_step: 20 - checkpoints: ./checkpoint - save_epoch: 1 - show_num: 10 - restore: False - restore_file : ./DB.pth.tar - -backbone: - function: ptocr.model.backbone.rec_vgg,vgg16_bn - -head: - function: ptocr.model.head.rec_CRNNHead,CRNN_Head - -architectures: - model_function: ptocr.model.architectures.rec_model,RecModel - loss_function: ptocr.model.architectures.rec_model,RecLoss - -loss: - function: ptocr.model.loss.ctc_loss,CTCLoss - reduction: 'mean' - -optimizer: - function: ptocr.optimizer,AdamDecay - base_lr: 0.001 - beta1: 0.9 - beta2: 0.999 - weight_decay: 0.00005 - -# optimizer: -# function: ptocr.optimizer,SGDDecay -# base_lr: 0.002 -# momentum: 0.99 -# weight_decay: 0.00005 - -optimizer_decay: - function: ptocr.optimizer,adjust_learning_rate_poly - factor: 0.9 - -#optimizer_decay: -# function: ptocr.optimizer,adjust_learning_rate -# schedule: [1,2] -# gama: 0.1 - -trainload: - function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessTrain - train_file: /src/notebooks/detect_text/icdar2015/recognize/train_list.txt - key_file: /src/notebooks/detect_text/icdar2015/recognize/key.txt - num_workers: 10 - batch_size: 32 - -testload: - function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessTest - test_file: /src/notebooks/detect_text/icdar2015/recognize/test_list.txt - num_workers: 5 - batch_size: 32 - -label_transform: - function: ptocr.utils.transform_label,strLabelConverter - -infer: - model_path: '' - path: '' - save_path: '' diff --git a/config/rec_FC_resnet_english_all.yaml b/config/rec_FC_resnet_english_all.yaml new file mode 100644 index 0000000..2a769de --- /dev/null +++ b/config/rec_FC_resnet_english_all.yaml @@ -0,0 +1,107 @@ +base: + gpu_id: '0,1' + algorithm: FC + pretrained: False + in_channels: 2048 + out_channels: 1024 + ignore_index: 37 + max_length: 25 + img_shape: [32,100] + is_gray: True + use_conv: False + use_attention: False + use_lstm: True + lstm_num: 2 + num_class: 36 + start_iters: 0 + max_iters: 300000 + eval_iter: 10000 + show_step: 100 + checkpoints: ./checkpoint + save_epoch: 1 + show_num: 10 + restore: False + finetune: False + restore_file : ./checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_20_20210207English/CRNN_best.pth.tar + +backbone: + function: ptocr.model.backbone.reg_resnet_bd,resnet50 + +head: + function: ptocr.model.head.rec_FCHead,FC_Head + +architectures: + model_function: ptocr.model.architectures.rec_model,RecModel + loss_function: ptocr.model.architectures.rec_model,RecLoss + +loss: + function: ptocr.model.loss.fc_loss,FCLoss + use_ctc_weight: False + reduction: 'mean' + center_function: ptocr.model.loss.centerloss,CenterLoss + use_center: False + center_lr: 0.5 + label_score: 0.95 +# min_score: 0.01 + weight_center: 0.000001 + + +optimizer: + function: ptocr.optimizer,AdamDecay + base_lr: 0.001 + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.00005 + +# optimizer: +# function: ptocr.optimizer,SGDDecay +# base_lr: 0.002 +# momentum: 0.99 +# weight_decay: 0.00005 + +# optimizer_decay: +# function: ptocr.optimizer,adjust_learning_rate_poly +# factor: 0.9 + +optimizer_decay: + function: ptocr.optimizer,adjust_learning_rate + schedule: [100000,200000] + gama: 0.1 + +optimizer_decay_center: + function: ptocr.optimizer,adjust_learning_rate_center + schedule: [80000,160000] + gama: 0.1 + +trainload: + function: ptocr.dataloader.RecLoad.CRNNProcess1,GetDataLoad + train_file: 
['/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/','/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/MJSynth'] + batch_ratio: [0.5,0.5] + key_file: /src/notebooks/MyworkData/EnglishCrnnData/key_new.txt + bg_path: ./bg_img/ + num_workers: 16 + batch_size: 256 + +valload: + function: ptocr.dataloader.RecLoad.CRNNProcess1,GetValDataLoad + root: '/src/notebooks/pytorchOCR-master/english_val_img' + dir: ['CUTE80','IC03_867','IC13_1015','IC13_857','IC15_1811','IIIT5k_3000','SVT','SVTP','IC15_2077'] + test_file: /src/notebooks/MyworkData/EnglishCrnnData/val_new.txt + num_workers: 2 + batch_size: 1 + + +label_transform: + function: ptocr.utils.transform_label,FCConverter + +transform: + function: ptocr.dataloader.RecLoad.DataAgument,transform_label + t_type: lower + char_type: En + +infer: +# model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar' + model_path: './checkpoint/ag_FC_bb_resnet34_he_FC_Head_bs_128_ep_200000_FC/FC_190000.pth.tar' +# path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg' + path: './english_val_img/' + save_path: '' diff --git "a/doc/md/\346\226\207\346\234\254\350\257\206\345\210\253\350\256\255\347\273\203\346\226\207\346\241\243.md" "b/doc/md/\346\226\207\346\234\254\350\257\206\345\210\253\350\256\255\347\273\203\346\226\207\346\241\243.md" index 6a9ca23..da5a0fc 100644 --- "a/doc/md/\346\226\207\346\234\254\350\257\206\345\210\253\350\256\255\347\273\203\346\226\207\346\241\243.md" +++ "b/doc/md/\346\226\207\346\234\254\350\257\206\345\210\253\350\256\255\347\273\203\346\226\207\346\241\243.md" @@ -4,24 +4,16 @@ You need a train_list.txt [example](https://github.com/BADBADBADBOY/pytorchOCR/blob/master/doc/example/rec_train_list.txt), with the format: absolute image path+\t+label. See the examples in data/example in this project. If you want to run validation during training, prepare a test_list.txt [example](https://github.com/BADBADBADBOY/pytorchOCR/blob/master/doc/example/rec_test_list.txt) in the same format. -#### Normal model training (using rec_CRNN_ori.yaml as an example) -1. In the yaml, set restore and finetune under base to False and use_center under loss to False; modify the data paths and other parameters. -2. Run the command below +#### Training a model +1. Modify the parameters in the corresponding algorithm's yaml in ./config; usually only the data paths need to be changed. +2. At the bottom of ./tools/rec_train.py, uncomment the yaml from config that corresponds to the desired algorithm. +3. Run the command below ``` -python3 ./tools/rec_train.py --config ./config/rec_CRNN_ori.yaml --log_str log - - -#### Training with CenterLoss (using rec_CRNN_ori.yaml as an example) -1. In the yaml, set restore and finetune under base to True and use_center under loss to True, and assign the path of the best model from normal training to restore_file under base. -2. Run the command below - -``` -python3 ./tools/rec_train.py --config ./config/rec_CRNN_ori.yaml --log_str log +python3 ./tools/rec_train.py ``` #### Testing a model -1. Assign the trained model to model_path under infer in the yaml and the image path to path -2. Run the command below +1. 
Run the command below ``` python3 ./tools/rec_infer.py diff --git a/doc/show/1.jpg b/doc/show/1.jpg new file mode 100644 index 0000000..b8a957b Binary files /dev/null and b/doc/show/1.jpg differ diff --git a/doc/show/2.jpg b/doc/show/2.jpg new file mode 100644 index 0000000..bebb5cf Binary files /dev/null and b/doc/show/2.jpg differ diff --git a/doc/show/3.jpg b/doc/show/3.jpg new file mode 100644 index 0000000..3e80900 Binary files /dev/null and b/doc/show/3.jpg differ diff --git a/finetune_prune_model.sh b/finetune_prune_model.sh deleted file mode 100644 index 6728b83..0000000 --- a/finetune_prune_model.sh +++ /dev/null @@ -1 +0,0 @@ -python3 tools/det_train.py --config ./config/det_DB_mobilev3.yaml --log_str total_prune_20201015_distil3 --pruned_model_dict_path ./checkpoint/ag_DB_bb_mobilenet_v3_small_he_DB_Head_bs_16_ep_1200_mobile_slim_all/pruned/pruned_dict.dict --prune_model_path ./checkpoint/ag_DB_bb_mobilenet_v3_small_he_DB_Head_bs_16_ep_1200_mobile_slim_all/pruned/pruned_dict.pth --prune_type total --n_epoch 200 --start_val 30 --base_lr 0.0008 --gpu_id 2 --t_ratio 0.1 --t_model_path ./checkpoint/ag_DB_bb_resnet50_he_DB_Head_bs_8_ep_1201/DB_best.pth.tar --t_config ./config/det_DB_resnet50_3_3.yaml \ No newline at end of file diff --git a/infer.sh b/infer.sh deleted file mode 100644 index c111f04..0000000 --- a/infer.sh +++ /dev/null @@ -1 +0,0 @@ -python3 ./tools/det_infer.py --config ./config/det_DB_mobilev3.yaml --model_path ./checkpoint/ag_DB_bb_mobilenet_v3_small_he_DB_Head_bs_16_ep_1200/DB_best.pth.tar --img_path /src/notebooks/detect_text/icdar2015/ch4_test_images --result_save_path ./result --onnx_path ./onnx/DBnet.onnx --trt_path ./onnx/DBnet_batch.engine --batch_size 2 --max_size 1536 --add_padding \ No newline at end of file diff --git "a/pre_model/\346\226\260\345\273\272\346\226\207\346\234\254\346\226\207\346\241\243.txt" "b/pre_model/\346\226\260\345\273\272\346\226\207\346\234\254\346\226\207\346\241\243.txt" deleted file mode 100644 index e69de29..0000000 diff --git a/ptocr/__pycache__/__init__.cpython-36.pyc b/ptocr/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000..04928da Binary files /dev/null and b/ptocr/__pycache__/__init__.cpython-36.pyc differ diff --git a/ptocr/__pycache__/optimizer.cpython-36.pyc b/ptocr/__pycache__/optimizer.cpython-36.pyc new file mode 100644 index 0000000..c85c005 Binary files /dev/null and b/ptocr/__pycache__/optimizer.cpython-36.pyc differ diff --git a/ptocr/dataloader/DetLoad/DBProcess.py b/ptocr/dataloader/DetLoad/DBProcess.py index 7fbe42b..749baed 100644 --- a/ptocr/dataloader/DetLoad/DBProcess.py +++ b/ptocr/dataloader/DetLoad/DBProcess.py @@ -94,7 +94,88 @@ def __getitem__(self, index): return img,gt,gt_mask,thresh_map,thresh_mask +class DBProcessTrainMul(data.Dataset): + def __init__(self,config): + super(DBProcessTrainMul,self).__init__() + self.crop_shape = config['base']['crop_shape'] + self.MBM = MakeBorderMap() + self.TSM = Random_Augment(self.crop_shape) + self.MSM = MakeSegMap(shrink_ratio = config['base']['shrink_ratio']) + img_list, label_list = self.get_base_information(config['trainload']['train_file']) + self.img_list = img_list + self.label_list = label_list + + def order_points(self, pts): + rect = np.zeros((4, 2), dtype="float32") + s = pts.sum(axis=1) + rect[0] = pts[np.argmin(s)] + rect[2] = pts[np.argmax(s)] + diff = np.diff(pts, axis=1) + rect[1] = pts[np.argmin(diff)] + rect[3] = pts[np.argmax(diff)] + return rect + + def get_bboxes(self,gt_path): + polys = [] + tags = [] + classes = [] + with 
open(gt_path, 'r', encoding='utf-8') as fid: + lines = fid.readlines() + for line in lines: + line = line.replace('\ufeff', '').replace('\xef\xbb\xbf', '') + gt = line.split(',') + if "#" in gt[-1]: + tags.append(True) + classes.append(-2) + else: + tags.append(False) + classes.append(int(gt[-1])) + # box = [int(gt[i]) for i in range(len(gt)//2*2)] + box = [int(gt[i]) for i in range(8)] + polys.append(box) + return np.array(polys), tags, classes + + def get_base_information(self,train_txt_file): + label_list = [] + img_list = [] + with open(train_txt_file,'r',encoding='utf-8') as fid: + lines = fid.readlines() + for line in lines: + line = line.strip('\n').split('\t') + img_list.append(line[0]) + result = self.get_bboxes(line[1]) + label_list.append(result) + return img_list,label_list + + def __len__(self): + return len(self.img_list) + + def __getitem__(self, index): + + img = Image.open(self.img_list[index]).convert('RGB') + img = np.array(img)[:,:,::-1] + + polys, dontcare, classes = self.label_list[index] + + img, polys = self.TSM.random_scale(img, polys, self.crop_shape[0]) + img, polys = self.TSM.random_rotate(img, polys) + img, polys = self.TSM.random_flip(img, polys) + img, polys, classes,dontcare = self.TSM.random_crop_db_mul(img, polys,classes, dontcare) + img, gt, classes,gt_mask = self.MSM.process_mul(img, polys, classes,dontcare) + img, thresh_map, thresh_mask = self.MBM.process(img, polys, dontcare) + + img = Image.fromarray(img).convert('RGB') + img = transforms.ColorJitter(brightness=32.0 / 255, saturation=0.5)(img) + img = self.TSM.normalize_img(img) + + + gt = torch.from_numpy(gt).float() + gt_classes = torch.from_numpy(classes).long() + gt_mask = torch.from_numpy(gt_mask).float() + thresh_map = torch.from_numpy(thresh_map).float() + thresh_mask = torch.from_numpy(thresh_mask).float() + return img,gt,gt_classes,gt_mask,thresh_map,thresh_mask class DBProcessTest(data.Dataset): def __init__(self,config): diff --git a/ptocr/dataloader/DetLoad/MakeSegMap.py b/ptocr/dataloader/DetLoad/MakeSegMap.py index 3a72c7d..02a8bff 100644 --- a/ptocr/dataloader/DetLoad/MakeSegMap.py +++ b/ptocr/dataloader/DetLoad/MakeSegMap.py @@ -27,6 +27,7 @@ def __init__(self, algorithm='DB',min_text_size = 8,shrink_ratio = 0.4,is_traini self.shrink_ratio = shrink_ratio self.is_training = is_training self.algorithm = algorithm + def process(self, img,polys,dontcare): ''' required keys: image, polygons, ignore_tags, filename adding keys: mask ''' @@ -77,7 +78,61 @@ def process(self, img,polys,dontcare): if self.algorithm == 'PAN': return img,gt_text,gt_text_key,gt,gt_kernel_key,mask return img,gt,mask + + def process_mul(self, img,polys,classes,dontcare): + ''' + required keys: + image, polygons, ignore_tags, filename + adding keys: + mask + ''' + h, w = img.shape[:2] + if self.is_training: + polys, dontcare = self.validate_polygons( + polys, dontcare, h, w) + gt = np.zeros((h, w), dtype=np.float32) + gt_class = np.zeros((h, w), dtype=np.float32) + mask = np.ones((h, w), dtype=np.float32) + + if self.algorithm =='PAN': + gt_text = np.zeros((h, w), dtype=np.float32) + gt_text_key = np.zeros((h, w), dtype=np.float32) + gt_kernel_key = np.zeros((h, w), dtype=np.float32) + for i in range(len(polys)): + polygon = polys[i] + height = max(polygon[:, 1]) - min(polygon[:, 1]) + width = max(polygon[:, 0]) - min(polygon[:, 0]) + if dontcare[i] or min(height, width) < self.min_text_size: + cv2.fillPoly(mask, polygon.astype( + np.int32)[np.newaxis, :, :], 0) + dontcare[i] = True + else: + if self.algorithm == 'PAN': + cv2.fillPoly(gt_text, [polygon.astype(np.int32)], 1) + 
cv2.fillPoly(gt_text_key, [polygon.astype(np.int32)], i + 1) + polygon_shape = Polygon(polygon) + distance = polygon_shape.area * \ + (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length + subject = [tuple(l) for l in polys[i]] + padding = pyclipper.PyclipperOffset() + padding.AddPath(subject, pyclipper.JT_ROUND, + pyclipper.ET_CLOSEDPOLYGON) + shrinked = padding.Execute(-distance) + if shrinked == []: + cv2.fillPoly(mask, polygon.astype( + np.int32)[np.newaxis, :, :], 0) + dontcare[i] = True + continue + shrinked = np.array(shrinked[0]).reshape(-1, 2) + cv2.fillPoly(gt, [shrinked.astype(np.int32)], 1) + cv2.fillPoly(gt_class, polygon.astype(np.int32)[np.newaxis, :, :], 1+classes[i]) + if self.algorithm == 'PAN': + cv2.fillPoly(gt_kernel_key, [shrinked.astype(np.int32)], i + 1) + if self.algorithm == 'PAN': + return img,gt_text,gt_text_key,gt,gt_kernel_key,mask + return img,gt,gt_class,mask + def validate_polygons(self, polygons, ignore_tags, h, w): ''' polygons (numpy.array, required): of shape (num_instances, num_points, 2) diff --git a/ptocr/dataloader/DetLoad/__pycache__/DBProcess.cpython-36.pyc b/ptocr/dataloader/DetLoad/__pycache__/DBProcess.cpython-36.pyc new file mode 100644 index 0000000..2035a23 Binary files /dev/null and b/ptocr/dataloader/DetLoad/__pycache__/DBProcess.cpython-36.pyc differ diff --git a/ptocr/dataloader/DetLoad/__pycache__/MakeBorderMap.cpython-36.pyc b/ptocr/dataloader/DetLoad/__pycache__/MakeBorderMap.cpython-36.pyc new file mode 100644 index 0000000..014c97a Binary files /dev/null and b/ptocr/dataloader/DetLoad/__pycache__/MakeBorderMap.cpython-36.pyc differ diff --git a/ptocr/dataloader/DetLoad/__pycache__/MakeSegMap.cpython-36.pyc b/ptocr/dataloader/DetLoad/__pycache__/MakeSegMap.cpython-36.pyc new file mode 100644 index 0000000..aaafb5a Binary files /dev/null and b/ptocr/dataloader/DetLoad/__pycache__/MakeSegMap.cpython-36.pyc differ diff --git a/ptocr/dataloader/DetLoad/__pycache__/SASTProcess_ori.cpython-36.pyc b/ptocr/dataloader/DetLoad/__pycache__/SASTProcess_ori.cpython-36.pyc new file mode 100644 index 0000000..93c4b9f Binary files /dev/null and b/ptocr/dataloader/DetLoad/__pycache__/SASTProcess_ori.cpython-36.pyc differ diff --git a/ptocr/dataloader/DetLoad/__pycache__/SASTProcess_ori1.cpython-36.pyc b/ptocr/dataloader/DetLoad/__pycache__/SASTProcess_ori1.cpython-36.pyc new file mode 100644 index 0000000..e3eaeab Binary files /dev/null and b/ptocr/dataloader/DetLoad/__pycache__/SASTProcess_ori1.cpython-36.pyc differ diff --git a/ptocr/dataloader/DetLoad/__pycache__/__init__.cpython-36.pyc b/ptocr/dataloader/DetLoad/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000..1c81771 Binary files /dev/null and b/ptocr/dataloader/DetLoad/__pycache__/__init__.cpython-36.pyc differ diff --git a/ptocr/dataloader/DetLoad/__pycache__/transform_img.cpython-36.pyc b/ptocr/dataloader/DetLoad/__pycache__/transform_img.cpython-36.pyc new file mode 100644 index 0000000..333ac1e Binary files /dev/null and b/ptocr/dataloader/DetLoad/__pycache__/transform_img.cpython-36.pyc differ diff --git a/ptocr/dataloader/DetLoad/transform_img.py b/ptocr/dataloader/DetLoad/transform_img.py index ce02ead..adbb77a 100644 --- a/ptocr/dataloader/DetLoad/transform_img.py +++ b/ptocr/dataloader/DetLoad/transform_img.py @@ -76,6 +76,38 @@ def process(self, img, polys, dont_care): new_dotcare.append(dont_care[i]) return img, new_polys, new_dotcare + + def process_mul(self, img, polys, classes, dont_care): + all_care_polys = [] + for i in range(len(dont_care)): + if 
(dont_care[i] is False): + all_care_polys.append(polys[i]) + crop_x, crop_y, crop_w, crop_h = self.crop_area(img, all_care_polys) + scale_w = self.size[0] / crop_w + scale_h = self.size[1] / crop_h + scale = min(scale_w, scale_h) + h = int(crop_h * scale) + w = int(crop_w * scale) + padimg = np.zeros( + (self.size[1], self.size[0], img.shape[2]), img.dtype) + padimg[:h, :w] = cv2.resize( + img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h)) + img = padimg + + new_polys = [] + new_dotcare = [] + new_classes = [] + + for i in range(len(polys)): + poly = polys[i] + poly = ((np.array(poly) - + (crop_x, crop_y)) * scale) + if not self.is_poly_outside_rect(poly, 0, 0, w, h): + new_polys.append(poly) + new_dotcare.append(dont_care[i]) + new_classes.append(classes[i]) + + return img, new_polys,new_classes, new_dotcare def is_poly_in_rect(self, poly, x, y, w, h): poly = np.array(poly) @@ -259,6 +291,10 @@ def random_flip(self, img, polys): def random_crop_db(self, img, polys, dont_care): img, new_polys, new_dotcare = self.random_crop_data.process(img, polys, dont_care) return img, new_polys, new_dotcare + + def random_crop_db_mul(self, img, polys,classes, dont_care): + img, new_polys, new_classes,new_dotcare = self.random_crop_data.process_mul(img, polys, classes,dont_care) + return img, new_polys,new_classes, new_dotcare def random_crop_pse(self, imgs, ): h, w = imgs[0].shape[0:2] diff --git a/ptocr/dataloader/RecLoad/CRNNProcess.py b/ptocr/dataloader/RecLoad/CRNNProcess.py index 68a370a..fb32fe2 100644 --- a/ptocr/dataloader/RecLoad/CRNNProcess.py +++ b/ptocr/dataloader/RecLoad/CRNNProcess.py @@ -1,8 +1,14 @@ +import lmdb import torch -import torchvision.transforms as transforms +import six,re,glob,os +import numpy as np +from PIL import Image,ImageFile +ImageFile.LOAD_TRUNCATED_IMAGES = True from torch.utils.data import Dataset -from .DataAgument import transform_image_add,transform_img_shape,transform_image_one -from PIL import Image +import torchvision.transforms as transforms +from ptocr.dataloader.RecLoad.DataAgument import transform_img_shape,DataAugment +from ptocr.utils.util_function import create_module,PILImageToCV,CVImageToPIL + def get_img(path,is_gray = False): img = Image.open(path).convert('RGB') @@ -10,54 +16,36 @@ def get_img(path,is_gray = False): img = img.convert('L') return img -class CRNNProcessTrain(Dataset): - def __init__(self, config): - super(CRNNProcessTrain,self).__init__() - with open(config['trainload']['train_file'],'r',encoding='utf-8') as fid: - self.label_list = [] - self.image_list = [] - - for line in fid.readlines(): - line = line.strip('\n').replace('\ufeff','').split('\t') - self.label_list.append(line[1]) - self.image_list.append(line[0]) - +class CRNNProcessLmdbLoad(Dataset): + def __init__(self, config,lmdb_type): self.config = config + self.lmdb_type = lmdb_type - def __len__(self): - return len(self.label_list) + if lmdb_type=='train': + lmdb_file = config['trainload']['train_file'] + workers = config['trainload']['num_workers'] + elif lmdb_type == 'val': + lmdb_file = config['valload']['val_file'] + workers = config['valload']['num_workers'] + else: + raise ValueError('lmdb_type error !!!') - def __getitem__(self, index): - img = get_img(self.image_list[index],is_gray=self.config['base']['is_gray']) -# img = transform_image_add(img) - img = transform_image_one(img) - img = transform_img_shape(img,self.config['base']['img_shape']) - img = Image.fromarray(img) - img = transforms.ToTensor()(img) - img.sub_(0.5).div_(0.5) - label = 
self.label_list[index] - return (img,label) - - -class CRNNProcessTrainLmdb(Dataset): - def __init__(self, config): - self.env = lmdb.open( - config['trainload']['train_file'], - max_readers=1, - readonly=True, - lock=False, - readahead=False, - meminit=False) + self.env = lmdb.open(lmdb_file,max_readers=workers,readonly=True,lock=False,readahead=False,meminit=False) if not self.env: - print('cannot creat lmdb from %s' % (root)) + print('cannot create lmdb from %s' % (lmdb_file)) sys.exit(0) with self.env.begin(write=False) as txn: nSamples = int(txn.get('num-samples'.encode('utf-8'))) self.nSamples = nSamples - self.config = config + self.transform_label = create_module(config['label_transform']['label_function']) + + self.bg_img = [] + for path in glob.glob(os.path.join(config['trainload']['bg_path'], '*')): + self.bg_img.append(path) def __len__(self): return self.nSamples @@ -79,22 +67,29 @@ def __getitem__(self, index): return self[index + 1] label_key = 'label-%09d' % index - label = txn.get(label_key.encode('utf-8')) - + label = txn.get(label_key.encode('utf-8')).decode() + label = self.transform_label(label, char_type=self.config['label_transform']['char_type'],t_type=self.config['label_transform']['t_type']) if self.config['base']['is_gray']: img = img.convert('L') - img = np.array(img) - - img = transform_image_one(img) - img = transform_img_shape(img,self.config['base']['img_shape']) - img = Image.fromarray(img) + img = PILImageToCV(img,self.config['base']['is_gray']) + if self.lmdb_type == 'train': + try: + bg_index = np.random.randint(0, len(self.bg_img)) + bg_img = PILImageToCV(get_img(self.bg_img[bg_index]),self.config['base']['is_gray']) + img = transform_img_shape(img, self.config['base']['img_shape']) + img = DataAugment(img, bg_img, self.config['base']['img_shape']) + img = transform_img_shape(img, self.config['base']['img_shape']) + except IOError: + print('Corrupted image for %d' % index) + return self[index + 1] + elif self.lmdb_type == 'val': + img = transform_img_shape(img, self.config['base']['img_shape']) + img = CVImageToPIL(img,self.config['base']['is_gray']) img = transforms.ToTensor()(img) img.sub_(0.5).div_(0.5) - - return (img, label) - + return (img, label) + class alignCollate(object): - def __init__(self,): pass def __call__(self, batch): @@ -102,28 +97,4 @@ def __call__(self, batch): images = torch.stack(images,0) return images,labels -class CRNNProcessTest(Dataset): - def __init__(self, config): - super(CRNNProcessTest,self).__init__() - with open(config['testload']['test_file'],'r',encoding='utf-8') as fid: - self.label_list = [] - self.image_list = [] - - for line in fid.readlines(): - line = line.strip('\n').replace('\ufeff','').split('\t') - self.label_list.append(line[1]) - self.image_list.append(line[0]) - - self.config = config - - def __len__(self): - return len(self.label_list) - - def __getitem__(self, index): - img = get_img(self.image_list[index],is_gray=self.config['base']['is_gray']) - img = transform_img_shape(img,self.config['base']['img_shape']) - img = Image.fromarray(img) - img = transforms.ToTensor()(img) - img.sub_(0.5).div_(0.5) - label = self.label_list[index] - return (img,label) \ No newline at end of file + \ No newline at end of file diff --git a/ptocr/dataloader/RecLoad/CRNNProcess1.py b/ptocr/dataloader/RecLoad/CRNNProcess1.py new file mode 100644 index 0000000..de9c6ec --- /dev/null +++ b/ptocr/dataloader/RecLoad/CRNNProcess1.py @@ -0,0 +1,244 @@ +#-*- coding:utf-8 _*- +""" +@author:fxw +@file: CRNNProcess1.py +@time: 
diff --git a/ptocr/dataloader/RecLoad/CRNNProcess1.py b/ptocr/dataloader/RecLoad/CRNNProcess1.py
new file mode 100644
index 0000000..de9c6ec
--- /dev/null
+++ b/ptocr/dataloader/RecLoad/CRNNProcess1.py
@@ -0,0 +1,244 @@
+#-*- coding:utf-8 _*-
+"""
+@author:fxw
+@file: CRNNProcess1.py
+@time: 2021/03/23
+"""
+
+import lmdb
+import torch
+import six,re,glob,os
+import numpy as np
+from PIL import Image
+from PIL import ImageFile
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+from torch.utils.data import Dataset
+import torchvision.transforms as transforms
+# DataAugment is the renamed transform_image_one (see the DataAgument.py diff below)
+from ptocr.dataloader.RecLoad.DataAgument import transform_img_shape,DataAugment
+from ptocr.utils.util_function import create_module
+
+
+def get_img(path,is_gray = False):
+    try:
+        img = Image.open(path).convert('RGB')
+    except Exception:
+        print(path)
+        # blank (H, W, C) fallback canvas for unreadable images
+        img = np.zeros((32, 320, 3), dtype=np.uint8)
+        img = Image.fromarray(img).convert('RGB')
+    if(is_gray):
+        img = img.convert('L')
+    return img
+
+class alignCollate(object):
+    def __init__(self, ):
+        pass
+    def __call__(self, batch):
+        images, labels = zip(*batch)
+        images = torch.stack(images, 0)
+        return images, labels
+
+class CRNNProcessTrainLmdb(Dataset):
+    def __init__(self, config, lmdb_path):
+        self.env = lmdb.open(
+            lmdb_path,
+            max_readers=config['trainload']['num_workers'],
+            readonly=True,
+            lock=False,
+            readahead=False,
+            meminit=False)
+
+        with self.env.begin(write=False) as txn:
+            nSamples = int(txn.get('num-samples'.encode('utf-8')))
+            self.nSamples = nSamples
+
+        self.config = config
+        self.transform_label = create_module(config['transform']['function'])
+        self.bg_img = []
+        for path in glob.glob(os.path.join(config['trainload']['bg_path'],'*')):
+            self.bg_img.append(path)
+
+
+    def __len__(self):
+        return self.nSamples
+
+    def __getitem__(self, index):
+        assert index <= len(self), 'index range error'
+        index += 1
+        with self.env.begin(write=False) as txn:
+
+            img_key = 'image-%09d' % index
+            imgbuf = txn.get(img_key.encode('utf-8'))
+            buf = six.BytesIO()
+            buf.write(imgbuf)
+            buf.seek(0)
+
+            try:
+                img = Image.open(buf).convert('RGB')
+            except IOError:
+                print('IO image for %d' % index)
+                return self[index + 1]
+
+            label_key = 'label-%09d' % index
+            label = txn.get(label_key.encode('utf-8'))
+
+        if self.config['base']['is_gray']:
+            img = img.convert('L')
+
+        img = np.array(img)
+        if isinstance(label,bytes):
+            label = label.decode()
+
+        label = self.transform_label(label,char_type=self.config['transform']['char_type'],t_type=self.config['transform']['t_type'])
+
+
+        try:
+            bg_index = np.random.randint(0,len(self.bg_img))
+            bg_img = np.array(get_img(self.bg_img[bg_index]))
+            img = transform_img_shape(img,self.config['base']['img_shape'])
+            img = DataAugment(img,bg_img,self.config['base']['img_shape'])
+            img = transform_img_shape(img,self.config['base']['img_shape'])
+        except Exception:
+            print('Corrupted image for %d' % index)
+            img = np.zeros((32,100)).astype(np.uint8)
+
+        img = Image.fromarray(img)
+        img = transforms.ToTensor()(img)
+        img.sub_(0.5).div_(0.5)
+
+        return (img, label)
+
+def GetDataLoad(config,data_type='train'):
+    if data_type == 'train':
+        lmdb_path_list = config['trainload']['train_file']
+        ratio = config['trainload']['batch_ratio']
+    elif data_type == 'val':
+        lmdb_path_list = config['valload']['val_file']
+
+    num = len(lmdb_path_list)
+    train_data_loaders = []
+    sum_num = 0
+    for i in range(num):
+        if data_type=='train':
+            if i==num-1:
+                # the last lmdb source absorbs the rounding remainder
+                batch_size = config['trainload']['batch_size'] - sum_num
+            else:
+                # e.g. batch_size=512, ratio=[0.33,0.33,0.33] -> per-source 168,168,176
+                batch_size = int(config['trainload']['batch_size']*ratio[i])//2*2
+                sum_num += batch_size
+            dataset = CRNNProcessTrainLmdb(config, lmdb_path_list[i])
+            num_workers = config['trainload']['num_workers']
+            shuffle = True
+        elif data_type == 'val':
+            batch_size = 1
+            num_workers = config['valload']['num_workers']
+            dataset = CRNNProcessTrainLmdb(config, lmdb_path_list[i])
+ 
shuffle = False + + train_data_loader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + collate_fn=alignCollate(), + drop_last=True, + pin_memory=True) + train_data_loaders.append(train_data_loader) + return train_data_loaders + +def GetValDataLoad(config): + + val_dir = config['valload']['dir'] + root= config['valload']['root'] + + num = len(val_dir) + data_loaders = [] + sum_num = 0 + for i in range(num): + + batch_size = 1 + num_workers = config['valload']['num_workers'] + config['valload']['test_file'] = os.path.join(root,val_dir[i],'val_train.txt') + dataset = CRNNProcessTest(config) + shuffle = False + + data_loader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + collate_fn=alignCollate(), + drop_last=True, + pin_memory=True) + data_loaders.append(data_loader) + return data_loaders + +class CRNNProcessTest(Dataset): + def __init__(self, config): + super(CRNNProcessTest,self).__init__() + with open(config['valload']['test_file'],'r',encoding='utf-8') as fid: + self.label_list = [] + self.image_list = [] + + for line in fid.readlines(): + line = line.strip('\n').replace('\ufeff','').split('\t') + self.label_list.append(line[1]) + self.image_list.append(line[0]) + + self.config = config + + def __len__(self): + return len(self.label_list) + + def __getitem__(self, index): + img = get_img(self.image_list[index],is_gray=self.config['base']['is_gray']) + img = transform_img_shape(img,self.config['base']['img_shape']) + img = Image.fromarray(img) + img = transforms.ToTensor()(img) + img.sub_(0.5).div_(0.5) + label = self.label_list[index] + return (img,label) + + +# if __name__ == "__main__": + +# config = {} +# config['base'] = {} +# config['base']['img_shape'] = [32,100] +# config['trainload'] = {} +# config['valload'] = {} +# config['trainload']['bg_path']='./bg_img/' +# config['base']['is_gray'] = True +# config['trainload']['train_file'] =['/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/','/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/MJSynth'] + +# config['valload']['val_file'] = ['D:\BaiduNetdiskDownload\data\evaluation\CUTE80', +# 'D:\BaiduNetdiskDownload\data\evaluation\IC03_860', +# 'D:\BaiduNetdiskDownload\data\evaluation\IC03_867'] +# config['trainload']['num_workers'] = 0 +# config['valload']['num_workers'] = 0 +# config['trainload']['batch_size'] = 128 +# config['trainload']['batch_ratio'] = [0.33, 0.33, 0.33] +# config['transform']={} + +# config['transform']['char_type'] = 'En' +# config['transform']['t_type']='lower' +# config['transform']['function'] = 'ptocr.dataloader.RecLoad.DataAgument,transform_label' +# import time +# train_data_loaders = GetDataLoad(config,data_type='train') +# # import pdb +# # pdb.set_trace() + +# # t = time.time() +# # for idx,data1 in enumerate(train_data_loaders[0]): +# # pass +# # print(time.time()-t) + +# data1 = enumerate(train_data_loaders[0]) +# for i in range(100): +# try: +# t = time.time() +# index,(data,label) = next(data1) +# print(time.time()-t) +# print(label) +# except: +# print('end') + + \ No newline at end of file diff --git a/ptocr/dataloader/RecLoad/DataAgument.py b/ptocr/dataloader/RecLoad/DataAgument.py index 1e797ce..c9af073 100644 --- a/ptocr/dataloader/RecLoad/DataAgument.py +++ b/ptocr/dataloader/RecLoad/DataAgument.py @@ -4,7 +4,7 @@ @file: DataAgument.py @time: 2019/06/06 """ -import cv2 +import cv2,re import numpy as np from skimage.util import random_noise import 
os
@@ -113,28 +113,16 @@ def perform_operation(self, images):
                 dy = random.randint(-self.magnitude, self.magnitude)
 
                 x1, y1, x2, y2, x3, y3, x4, y4 = polygons[a]
-                polygons[a] = [x1, y1,
-                               x2, y2,
-                               x3 + dx, y3 + dy,
-                               x4, y4]
+                polygons[a] = [x1, y1,x2, y2,x3 + dx, y3 + dy,x4, y4]
 
                 x1, y1, x2, y2, x3, y3, x4, y4 = polygons[b]
-                polygons[b] = [x1, y1,
-                               x2 + dx, y2 + dy,
-                               x3, y3,
-                               x4, y4]
+                polygons[b] = [x1, y1, x2 + dx, y2 + dy,x3, y3,x4, y4]
 
                 x1, y1, x2, y2, x3, y3, x4, y4 = polygons[c]
-                polygons[c] = [x1, y1,
-                               x2, y2,
-                               x3, y3,
-                               x4 + dx, y4 + dy]
+                polygons[c] = [x1, y1,x2, y2,x3, y3,x4 + dx, y4 + dy]
 
                 x1, y1, x2, y2, x3, y3, x4, y4 = polygons[d]
-                polygons[d] = [x1 + dx, y1 + dy,
-                               x2, y2,
-                               x3, y3,
-                               x4, y4]
+                polygons[d] = [x1 + dx, y1 + dy,x2, y2,x3, y3,x4, y4]
 
         generated_mesh = []
         for i in range(len(dimensions)):
@@ -306,7 +294,7 @@ def transform_img_shape(img, img_shape):
     img = np.array(img)
     H, W = img_shape
     h, w = img.shape[:2]
-    new_w = int((float(H)/ h) * w)
+    new_w = int((float(H)/h) * w)
     if (new_w > W):
         img = cv2.resize(img, (W, H))
     else:
@@ -328,11 +316,23 @@ def random_crop(image):
         crop_img = crop_img[0:h - top_crop, :]
     return crop_img
 
-def transform_image_one(image):
+def get_background_Amg(img,bg_img,img_shape=[32,200]):
+    H,W = img_shape
+    # sample a random H x W patch from the background image
+    x = np.random.randint(0, bg_img.shape[0] - H + 1)
+    y = np.random.randint(0, bg_img.shape[1] - W + 1)
+    img_crop = bg_img[x:x + H, y:y + W]
+    ratio = np.random.randint(1,4)/10.0+np.random.randint(0,10)/100.0
+    img = np.array(Image.fromarray(img).convert('RGB'))
+    dst = cv2.addWeighted(img, 1-ratio, img_crop, ratio, 2)
+    dst = np.array(Image.fromarray(dst).convert('L'))
+    return dst
+
+def DataAugment(image,bg_img,img_shape):
     image = np.array(image)
     if (np.random.choice([True, False], 1)[0]):
         dataAu = DataAugmentatonMore(image)
-        index = np.random.randint(0,10)
+        index = np.random.randint(0,11)
         if (index == 0):
             degree = np.random.randint(2, 6)
             angle = np.random.randint(0, 360)
@@ -367,62 +367,24 @@ def transform_image_one(image):
             image = RandomAddLine(image)
         elif(index==9):
             image = random_crop(image)
-        del dataAu
+        elif(index==10):
+            image = get_background_Amg(image,bg_img,img_shape)
+        del dataAu,bg_img
     return image
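
`DataAugment` (the renamed `transform_image_one`) now draws one of eleven random augmentations, and the new index 10 blends a random background patch into the text crop. A minimal usage sketch under stated assumptions: the crop path is hypothetical, `bg_img/1.jpg` is one of the backgrounds added by this commit, and the background is loaded in colour because `get_background_Amg` converts the crop to RGB before `cv2.addWeighted`:

```python
import cv2
from ptocr.dataloader.RecLoad.DataAgument import transform_img_shape, DataAugment

crop = cv2.imread('word_crop.jpg', cv2.IMREAD_GRAYSCALE)  # hypothetical text-line crop
bg = cv2.imread('bg_img/1.jpg')  # 3-channel background for the blend branch

crop = transform_img_shape(crop, [32, 200])  # scale to H=32, fit width to 200
aug = DataAugment(crop, bg, [32, 200])       # one random op, or the no-op branch
cv2.imwrite('aug_preview.jpg', aug)
```
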
 
-def transform_image_add(image):
-    image = np.array(image)
-    is_transform = np.random.choice([True, False], 1)[0]
-    # image = RandomAddLine(image)
-    if (is_transform):
-        dataAu = DataAugmentatonMore(image)
-        is_transform = np.random.choice([True, False], 1)[0]
-        if (is_transform):
-            image = dataAu.image
-            image = random_dilute(image)
-            dataAu.image = image
-        is_transform = np.random.choice([True, False], 1)[0]
-        if (is_transform):
-            image = dataAu.image
-            image = Image.fromarray(image)
-            image = GetRandomDistortImage([image])[0]
-            dataAu.image = image
-        is_transform = np.random.choice([True, False], 1)[0]
-        if (is_transform):
-            degree = np.random.randint(2, 8)
-            angle = np.random.randint(0, 360)
-            image = dataAu.motion_blur(degree, angle)
-            dataAu.image = image
-        is_transform = np.random.choice([True, False], 1)[0]
-        if (is_transform):
-            id = np.random.randint(0, 3)
-            k_size = [3, 5, 7]
-            image = dataAu.gaussian_blur(k_size[id])
-            dataAu.image = image
-        is_transform = np.random.choice([True, False], 1)[0]
-        if (is_transform):
-            alpha = np.random.uniform(0.6, 1.3)
-            image = dataAu.Contrast_and_Brightness(alpha)
-            dataAu.image = image
-        is_transform = np.random.choice([True, False], 1)[0]
-        if (is_transform):
-            types = ['top', 'botttom']
-            id = np.random.randint(0, 1)
-            ratio = np.random.randint(0, 5)
-            image = dataAu.Perspective(ratio, types[id])
-            dataAu.image = image
-        is_transform = np.random.choice([True, False], 1)[0]
-        if (is_transform and image.shape[0] > 20):
-            ratio = np.random.uniform(0.45, 0.7)
-            image = dataAu.resize_blur(ratio)
-            dataAu.image = image
-        del dataAu
-    return image
-
-
-
+def transform_label(label,char_type='En',t_type = 'lower'):
+    if char_type == 'En':
+        if t_type == 'lower':
+            label = ''.join(re.findall('[0-9a-zA-Z]+', label)).lower()
+        elif t_type == 'upper':
+            label = ''.join(re.findall('[0-9a-zA-Z]+', label)).upper()
+        else:
+            label = ''.join(re.findall('[0-9a-zA-Z]+', label))
+    # 'Ch' (and any other char_type) passes the label through unchanged
+    return label
diff --git a/ptocr/dataloader/RecLoad/__pycache__/CRNNProcess.cpython-36.pyc b/ptocr/dataloader/RecLoad/__pycache__/CRNNProcess.cpython-36.pyc
new file mode 100644
index 0000000..2328be6
Binary files /dev/null and b/ptocr/dataloader/RecLoad/__pycache__/CRNNProcess.cpython-36.pyc differ
diff --git a/ptocr/dataloader/RecLoad/__pycache__/CRNNProcess1.cpython-36.pyc b/ptocr/dataloader/RecLoad/__pycache__/CRNNProcess1.cpython-36.pyc
new file mode 100644
index 0000000..9fb131a
Binary files /dev/null and b/ptocr/dataloader/RecLoad/__pycache__/CRNNProcess1.cpython-36.pyc differ
diff --git a/ptocr/dataloader/RecLoad/__pycache__/DataAgument.cpython-36.pyc b/ptocr/dataloader/RecLoad/__pycache__/DataAgument.cpython-36.pyc
new file mode 100644
index 0000000..df27415
Binary files /dev/null and b/ptocr/dataloader/RecLoad/__pycache__/DataAgument.cpython-36.pyc differ
diff --git a/ptocr/dataloader/RecLoad/__pycache__/__init__.cpython-36.pyc b/ptocr/dataloader/RecLoad/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..093e441
Binary files /dev/null and b/ptocr/dataloader/RecLoad/__pycache__/__init__.cpython-36.pyc differ
diff --git a/ptocr/dataloader/__pycache__/__init__.cpython-36.pyc b/ptocr/dataloader/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..92fbd37
Binary files /dev/null and b/ptocr/dataloader/__pycache__/__init__.cpython-36.pyc differ
diff --git a/ptocr/model/__pycache__/CommonFunction.cpython-36.pyc b/ptocr/model/__pycache__/CommonFunction.cpython-36.pyc
new file mode 100644
index 0000000..cb2c691
Binary files /dev/null and b/ptocr/model/__pycache__/CommonFunction.cpython-36.pyc differ
diff --git a/ptocr/model/__pycache__/__init__.cpython-36.pyc b/ptocr/model/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..5211216
Binary files /dev/null and b/ptocr/model/__pycache__/__init__.cpython-36.pyc differ
diff --git a/ptocr/model/architectures/__pycache__/__init__.cpython-36.pyc b/ptocr/model/architectures/__pycache__/__init__.cpython-36.pyc
index 31b4481..6b35176 100644
Binary files a/ptocr/model/architectures/__pycache__/__init__.cpython-36.pyc and b/ptocr/model/architectures/__pycache__/__init__.cpython-36.pyc differ
diff --git a/ptocr/model/architectures/__pycache__/det_model.cpython-36.pyc b/ptocr/model/architectures/__pycache__/det_model.cpython-36.pyc
index 1c97390..1f3381f 100644
Binary files a/ptocr/model/architectures/__pycache__/det_model.cpython-36.pyc and b/ptocr/model/architectures/__pycache__/det_model.cpython-36.pyc differ
diff --git a/ptocr/model/architectures/__pycache__/rec_model.cpython-36.pyc b/ptocr/model/architectures/__pycache__/rec_model.cpython-36.pyc
index 608dd16..ede5104 100644
Binary files a/ptocr/model/architectures/__pycache__/rec_model.cpython-36.pyc and b/ptocr/model/architectures/__pycache__/rec_model.cpython-36.pyc differ
diff --git a/ptocr/model/architectures/__pycache__/stn.cpython-36.pyc b/ptocr/model/architectures/__pycache__/stn.cpython-36.pyc
new file mode 100644
index 0000000..ceb29c2
Binary files /dev/null and b/ptocr/model/architectures/__pycache__/stn.cpython-36.pyc differ
diff --git a/ptocr/model/architectures/__pycache__/stn_head.cpython-36.pyc b/ptocr/model/architectures/__pycache__/stn_head.cpython-36.pyc
new file mode 100644
index 0000000..62560ad
Binary files /dev/null and b/ptocr/model/architectures/__pycache__/stn_head.cpython-36.pyc differ
diff --git a/ptocr/model/architectures/__pycache__/tps_spatial_transformer.cpython-36.pyc b/ptocr/model/architectures/__pycache__/tps_spatial_transformer.cpython-36.pyc
new file mode 100644
index 0000000..e697196
Binary files /dev/null and b/ptocr/model/architectures/__pycache__/tps_spatial_transformer.cpython-36.pyc differ
diff --git a/ptocr/model/architectures/det_model.py b/ptocr/model/architectures/det_model.py
index 0f91513..78f7337 100644
--- a/ptocr/model/architectures/det_model.py
+++ b/ptocr/model/architectures/det_model.py
@@ -20,9 +20,15 @@ def __init__(self, config):
 
         self.head = create_module(config['head']['function']) \
             (config['base']['in_channels'], config['base']['inner_channels'])
-
+        self.mulclass = False  # flipped on when 'n_class' is configured
         if (config['base']['algorithm']) == 'DB':
-            self.seg_out = create_module(config['segout']['function'])(config['base']['inner_channels'],
+            if 'n_class' in config['base'].keys():
+                self.mulclass = True
+                self.seg_out = create_module(config['segout']['function'])(config['base']['n_class'],config['base']['inner_channels'],
+                                                                           config['base']['k'],
+                                                                           config['base']['adaptive'])
+            else:
+                self.seg_out = create_module(config['segout']['function'])(config['base']['inner_channels'],
                                                                            config['base']['k'],
                                                                            config['base']['adaptive'])
         elif (config['base']['algorithm']) == 'PAN':
@@ -40,14 +46,21 @@ def forward(self, data):
         if self.training:
             if self.algorithm == "DB":
-                img, gt, gt_mask, thresh_map, thresh_mask = data
+                if self.mulclass:
+                    img, gt, gt_class, gt_mask, thresh_map, thresh_mask = data
+                else:
+                    img, gt, gt_mask, thresh_map, thresh_mask = data
                 if torch.cuda.is_available():
+                    if self.mulclass:
+                        gt_class = gt_class.cuda()
                     img, gt, gt_mask, thresh_map, thresh_mask = \
                         img.cuda(), gt.cuda(), gt_mask.cuda(), thresh_map.cuda(), thresh_mask.cuda()
                 gt_batch = dict(gt=gt)
                 gt_batch['mask'] = gt_mask
                 gt_batch['thresh_map'] = thresh_map
                 gt_batch['thresh_mask'] = thresh_mask
+                if self.mulclass:
+                    gt_batch['gt_class'] = gt_class
 
             elif self.algorithm == "PSE":
                 img, gt_text, gt_kernels, train_mask = data
@@ -98,8 +111,12 @@ def __init__(self, config):
         super(DetLoss, self).__init__()
         self.algorithm = config['base']['algorithm']
         if (config['base']['algorithm']) == 'DB':
-            self.loss = create_module(config['loss']['function'])(config['loss']['l1_scale'],
-                                                                  config['loss']['bce_scale'])
+            if 'n_class' in config['base'].keys():
+                self.loss = create_module(config['loss']['function'])(config['base']['n_class'],config['loss']['l1_scale'],
+                                                                      config['loss']['bce_scale'],config['loss']['class_scale'])
+            else:
+                self.loss = create_module(config['loss']['function'])(config['loss']['l1_scale'],config['loss']['bce_scale'])
+
         elif (config['base']['algorithm']) == 'PAN':
             self.loss = create_module(config['loss']['function'])(config['loss']['kernel_rate'],
                                                                   config['loss']['agg_dis_rate'])
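
The multi-class DB branch above is driven purely by configuration: when `base` contains an `n_class` entry, `DetModel` unpacks an extra `gt_class` target and `DetLoss` forwards the additional `class_scale` weight. A minimal illustration of the toggle; the dict keys mirror the code above, while the values are made-up examples rather than shipped defaults:

```python
# Sketch of the relevant config fragment, written as a plain dict for illustration.
db_config = {
    'base': {'algorithm': 'DB', 'n_class': 2},        # e.g. two text classes
    'loss': {'l1_scale': 10, 'bce_scale': 1, 'class_scale': 1.0},
}

mulclass = 'n_class' in db_config['base']  # mirrors the check in DetModel.__init__
print(mulclass)  # True -> class-aware SegDetector/DBLoss; drop 'n_class' to revert
```

diff --git a/ptocr/model/architectures/paddle_tps.py b/ptocr/model/architectures/paddle_tps.py
new file mode 100644
index 0000000..c1f5a44
--- /dev/null
+++ 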
b/ptocr/model/architectures/paddle_tps.py @@ -0,0 +1,252 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import torch +import torch.nn as nn +import numpy as np +import torch.nn.functional as F + +class ConvBNLayer(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + groups=1, + is_relu=False): + super(ConvBNLayer, self).__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + groups=groups, + ) + self.bn = nn.BatchNorm2d( + out_channels) + if is_relu: + self.relu = nn.ReLU() + self.is_relu = is_relu + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + if self.is_relu: + x = self.relu(x) + return x + + +class LocalizationNetwork(nn.Module): + def __init__(self, in_channels, num_fiducial, model_name): + super(LocalizationNetwork, self).__init__() + self.F = num_fiducial + F = num_fiducial + if model_name == "large": + num_filters_list = [64, 128, 256, 512] + fc_dim = 256 + else: + num_filters_list = [16, 32, 64, 128] + fc_dim = 64 + + block_list = [] + for fno in range(0, len(num_filters_list)): + num_filters = num_filters_list[fno] + conv = ConvBNLayer( + in_channels=in_channels, + out_channels=num_filters, + kernel_size=3, + is_relu=True) + block_list.append(conv) + if fno == len(num_filters_list) - 1: + pool = nn.AdaptiveAvgPool2d(1) + else: + pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + in_channels = num_filters + block_list.append(pool) + self.fc1 = nn.Linear( + in_channels, + fc_dim) + + # Init fc2 in LocalizationNetwork + self.fc2 = nn.Linear( + fc_dim, + F * 2) + self.out_channels = F * 2 + self.block_list = nn.Sequential(*block_list) + + def forward(self, x): + """ + Estimating parameters of geometric transformation + Args: + image: input + Return: + batch_C_prime: the matrix of the geometric transformation + """ + B = x.shape[0] + i = 0 + for block in self.block_list: + x = block(x) + x = x.squeeze(2).squeeze(2) + x = self.fc1(x) + + x = F.relu(x) + x = self.fc2(x) + x = x.reshape(shape=[-1, self.F, 2]) + return x + + def get_initial_fiducials(self): + """ see RARE paper Fig. 6 (a) """ + F = self.F + ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) + ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) + ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) + return initial_bias + + +class GridGenerator(nn.Module): + def __init__(self, in_channels, num_fiducial): + super(GridGenerator, self).__init__() + self.eps = 1e-6 + self.F = num_fiducial + self.fc = nn.Linear( + in_channels, + 6) + + def forward(self, batch_C_prime, I_r_size): + """ + Generate the grid for the grid_sampler. 
+ Args: + batch_C_prime: the matrix of the geometric transformation + I_r_size: the shape of the input image + Return: + batch_P_prime: the grid for the grid_sampler + """ + C = self.build_C_paddle() + P = self.build_P_paddle(I_r_size) + + inv_delta_C_tensor = self.build_inv_delta_C_paddle(C) + P_hat_tensor = self.build_P_hat_paddle( + C, torch.Tensor(P)) + + inv_delta_C_tensor.stop_gradient = True + P_hat_tensor.stop_gradient = True + + batch_C_ex_part_tensor = self.get_expand_tensor(batch_C_prime) + + batch_C_ex_part_tensor.stop_gradient = True + + batch_C_prime_with_zeros = torch.cat([batch_C_prime, batch_C_ex_part_tensor],1) + batch_T = torch.matmul(inv_delta_C_tensor, batch_C_prime_with_zeros) + batch_P_prime = torch.matmul(P_hat_tensor, batch_T) + return batch_P_prime + + def build_C_paddle(self): + """ Return coordinates of fiducial points in I_r; C """ + F = self.F + ctrl_pts_x = torch.linspace(-1.0, 1.0, int(F / 2)) + ctrl_pts_y_top = -1 * torch.ones([int(F / 2)]) + ctrl_pts_y_bottom = torch.ones([int(F / 2)]) + ctrl_pts_top = torch.stack([ctrl_pts_x, ctrl_pts_y_top], 1) + ctrl_pts_bottom = torch.stack([ctrl_pts_x, ctrl_pts_y_bottom],1) + C = torch.cat([ctrl_pts_top, ctrl_pts_bottom], 0) + return C # F x 2 + + def build_P_paddle(self, I_r_size): + I_r_height, I_r_width = I_r_size + I_r_grid_x = (torch.arange( + -I_r_width, I_r_width, 2) + 1.0 + ) / torch.Tensor(np.array([I_r_width])) + + I_r_grid_y = (torch.arange( + -I_r_height, I_r_height, 2) + 1.0 + ) / torch.Tensor(np.array([I_r_height])) + + # P: self.I_r_width x self.I_r_height x 2 + P = torch.stack(torch.meshgrid(I_r_grid_x, I_r_grid_y), 2) + P = P.permute(1, 0, 2) + # n (= self.I_r_width x self.I_r_height) x 2 + return P.reshape([-1, 2]) + + def build_inv_delta_C_paddle(self, C): + """ Return inv_delta_C which is needed to calculate T """ + F = self.F + hat_C = torch.zeros((F, F)) # F x F + for i in range(0, F): + for j in range(i, F): + if i == j: + hat_C[i, j] = 1 + else: + r = torch.norm(C[i] - C[j]) + hat_C[i, j] = r + hat_C[j, i] = r + hat_C = (hat_C**2) * torch.log(hat_C) + delta_C = torch.cat( # F+3 x F+3 + [ + torch.cat( + [torch.ones( + (F, 1)), C, hat_C], 1), # F x F+3 + torch.cat( + [ + torch.zeros( + (2, 3)), torch.transpose( + C, 1, 0) + ],1), # 2 x F+3 + torch.cat( + [ + torch.zeros( + (1, 3)), torch.ones( + (1, F)) + ],1) # 1 x F+3 +],0) + inv_delta_C = torch.inverse(delta_C) + return inv_delta_C # F+3 x F+3 + + def build_P_hat_paddle(self, C, P): + F = self.F + eps = self.eps + n = P.shape[0] # n (= self.I_r_width x self.I_r_height) + # P_tile: n x 2 -> n x 1 x 2 -> n x F x 2 + P_tile = P.unsqueeze(1).repeat(1, F, 1) + C_tile = torch.unsqueeze(C, 0) # 1 x F x 2 + P_diff = P_tile - C_tile # n x F x 2 + # rbf_norm: n x F + rbf_norm = torch.norm(P_diff, p=2, dim=2, keepdim=False) + + # rbf: n x F + rbf = torch.mul( + torch.square(rbf_norm), torch.log(rbf_norm + eps)) + P_hat = torch.cat( + [torch.ones( + (n, 1)), P, rbf],1) + return P_hat # n x F+3 + + def get_expand_tensor(self, batch_C_prime): + B, H, C = batch_C_prime.shape + batch_C_prime = batch_C_prime.reshape([B, H * C]) + batch_C_ex_part_tensor = self.fc(batch_C_prime) + batch_C_ex_part_tensor = batch_C_ex_part_tensor.reshape([-1, 3, 2]) + return batch_C_ex_part_tensor + + +class TPS(nn.Module): + def __init__(self, in_channels, num_fiducial, model_name): + super(TPS, self).__init__() + self.loc_net = LocalizationNetwork(in_channels, num_fiducial, + model_name) + self.grid_generator = GridGenerator(self.loc_net.out_channels, + num_fiducial) + 
self.out_channels = in_channels + + def forward(self, image): + batch_C_prime = self.loc_net(image) + batch_P_prime = self.grid_generator(batch_C_prime, image.shape[2:]) + batch_P_prime = batch_P_prime.reshape( + [-1, image.shape[2], image.shape[3], 2]) + batch_I_r = F.grid_sample(image, grid=batch_P_prime,align_corners=True) + return batch_I_r + diff --git a/ptocr/model/architectures/rec_model.py b/ptocr/model/architectures/rec_model.py index 9df17b2..9f5c2d4 100644 --- a/ptocr/model/architectures/rec_model.py +++ b/ptocr/model/architectures/rec_model.py @@ -6,27 +6,107 @@ """ import torch import torch.nn as nn +import torch.nn.functional as F from .. import create_module +import cv2 class RecModel(nn.Module): def __init__(self, config): super(RecModel, self).__init__() - self.algorithm = config['base']['algorithm'] + self.algorithm = config['base']['algorithm'] + self.backbone = create_module(config['backbone']['function'])(config['base']['pretrained'],config['base']['is_gray']) - self.head = create_module(config['head']['function'])( - use_conv=config['base']['use_conv'], - use_attention=config['base']['use_attention'], - use_lstm=config['base']['use_lstm'], - lstm_num=config['base']['lstm_num'], - inchannel=config['base']['inchannel'], - hiddenchannel=config['base']['hiddenchannel'], - classes=config['base']['classes']) + if self.algorithm=='CRNN': + self.head = create_module(config['head']['function'])( + use_attention=config['base']['use_attention'], + use_lstm=config['base']['use_lstm'], + time_step=config['base']['img_shape'][1]//4, + lstm_num=config['base']['lstm_num'], + inchannel=config['base']['inchannel'], + hiddenchannel=config['base']['hiddenchannel'], + classes=config['base']['classes']) + elif self.algorithm=='FC': + self.head = create_module(config['head']['function'])(in_channels=config['base']['in_channels'], + out_channels=config['base']['out_channels'], + max_length = config['base']['max_length'], + num_class = config['base']['num_class']) + + def forward(self, x): + x1 = self.backbone(x) + x1,feau = self.head(x1) + return x1,feau + +# class RecModel(nn.Module): +# def __init__(self, config): +# super(RecModel, self).__init__() +# self.algorithm = config['base']['algorithm'] +# if config['base']['is_gray']: +# in_planes = 1 +# else: +# in_planes = 3 + +# self.backbone = create_module(config['backbone']['function'])(config['base']['pretrained'],config['base']['is_gray']) +# self.head = create_module(config['head']['function'])( +# use_conv=config['base']['use_conv'], +# use_attention=config['base']['use_attention'], +# use_lstm=config['base']['use_lstm'], +# lstm_num=config['base']['lstm_num'], +# inchannel=config['base']['inchannel'], +# hiddenchannel=config['base']['hiddenchannel'], +# classes=config['base']['classes']) + +# self.stn_head = create_module(config['stn']['function'])(in_planes,config) + +# def forward(self, x): +# cv2.imwrite('stn_ori1.jpg',(x[0,0].cpu().detach().numpy()*0.5+0.5)*255) + +# x1= self.stn_head(x) +# cv2.imwrite('stn1.jpg',(x1[0,0].cpu().detach().numpy()*0.5+0.5)*255) + +# x1 = self.backbone(x1) +# x1 = self.head(x1) +# return x1 + +# class RecModel(nn.Module): +# def __init__(self, config): +# super(RecModel, self).__init__() +# self.algorithm = config['base']['algorithm'] +# self.tps_inputsize = config['stn']['tps_inputsize'] +# if config['base']['is_gray']: +# in_planes = 1 +# else: +# in_planes = 3 + +# self.backbone = create_module(config['backbone']['function'])(config['base']['pretrained'],config['base']['is_gray']) +# self.head = 
create_module(config['head']['function'])( +# use_conv=config['base']['use_conv'], +# use_attention=config['base']['use_attention'], +# use_lstm=config['base']['use_lstm'], +# lstm_num=config['base']['lstm_num'], +# inchannel=config['base']['inchannel'], +# hiddenchannel=config['base']['hiddenchannel'], +# classes=config['base']['classes']) + +# self.tps = create_module(config['stn']['t_function'])( output_image_size=tuple(config['stn']['tps_outputsize']), +# num_control_points=config['stn']['num_control_points'], +# margins=tuple(config['stn']['tps_margins'])) +# self.stn_head = create_module(config['stn']['function'])(in_planes=in_planes, +# num_ctrlpoints=config['stn']['num_control_points'], +# activation=config['stn']['stn_activation']) + +# def forward(self, x): +# cv2.imwrite('stn_ori.jpg',(x[0,0].cpu().detach().numpy()*0.5+0.5)*255) +# stn_input = F.interpolate(x, self.tps_inputsize, mode='bilinear', align_corners=True) +# stn_img_feat, ctrl_points = self.stn_head(stn_input) - def forward(self, img): - x = self.backbone(img) - x = self.head(x) - return x +# x1, _ = self.tps(x, ctrl_points) + +# cv2.imwrite('stn.jpg',(x1[0,0].cpu().detach().numpy()*0.5+0.5)*255) + +# x1 = self.backbone(x1) +# x1 = self.head(x1) +# return x1 class RecLoss(nn.Module): @@ -35,6 +115,8 @@ def __init__(self, config): self.algorithm = config['base']['algorithm'] if (config['base']['algorithm']) == 'CRNN': self.loss = create_module(config['loss']['function'])(config) + elif self.algorithm=='FC': + self.loss = create_module(config['loss']['function'])(ignore_index = config['base']['ignore_index']) else: assert True == False, ('not support this algorithm !!!') diff --git a/ptocr/model/backbone/__pycache__/__init__.cpython-36.pyc b/ptocr/model/backbone/__pycache__/__init__.cpython-36.pyc index a34ccd4..1f5e306 100644 Binary files a/ptocr/model/backbone/__pycache__/__init__.cpython-36.pyc and b/ptocr/model/backbone/__pycache__/__init__.cpython-36.pyc differ diff --git a/ptocr/model/backbone/__pycache__/det_mobilev3.cpython-36.pyc b/ptocr/model/backbone/__pycache__/det_mobilev3.cpython-36.pyc index bc7d5fc..1a840d5 100644 Binary files a/ptocr/model/backbone/__pycache__/det_mobilev3.cpython-36.pyc and b/ptocr/model/backbone/__pycache__/det_mobilev3.cpython-36.pyc differ diff --git a/ptocr/model/backbone/__pycache__/det_mobilev3_dcd.cpython-36.pyc b/ptocr/model/backbone/__pycache__/det_mobilev3_dcd.cpython-36.pyc new file mode 100644 index 0000000..24f3a9d Binary files /dev/null and b/ptocr/model/backbone/__pycache__/det_mobilev3_dcd.cpython-36.pyc differ diff --git a/ptocr/model/backbone/__pycache__/det_resnet.cpython-36.pyc b/ptocr/model/backbone/__pycache__/det_resnet.cpython-36.pyc index 58ed83d..cd01ea5 100644 Binary files a/ptocr/model/backbone/__pycache__/det_resnet.cpython-36.pyc and b/ptocr/model/backbone/__pycache__/det_resnet.cpython-36.pyc differ diff --git a/ptocr/model/backbone/__pycache__/det_resnet_3_3.cpython-36.pyc b/ptocr/model/backbone/__pycache__/det_resnet_3_3.cpython-36.pyc index bf5e35e..6998a2a 100644 Binary files a/ptocr/model/backbone/__pycache__/det_resnet_3_3.cpython-36.pyc and b/ptocr/model/backbone/__pycache__/det_resnet_3_3.cpython-36.pyc differ diff --git a/ptocr/model/backbone/__pycache__/det_resnet_sast.cpython-36.pyc b/ptocr/model/backbone/__pycache__/det_resnet_sast.cpython-36.pyc index c3ad5ee..cc86e8b 100644 Binary files a/ptocr/model/backbone/__pycache__/det_resnet_sast.cpython-36.pyc and b/ptocr/model/backbone/__pycache__/det_resnet_sast.cpython-36.pyc differ diff --git 
a/ptocr/model/backbone/__pycache__/det_resnet_sast_3_3.cpython-36.pyc b/ptocr/model/backbone/__pycache__/det_resnet_sast_3_3.cpython-36.pyc index 705ca8f..6786036 100644 Binary files a/ptocr/model/backbone/__pycache__/det_resnet_sast_3_3.cpython-36.pyc and b/ptocr/model/backbone/__pycache__/det_resnet_sast_3_3.cpython-36.pyc differ diff --git a/ptocr/model/backbone/__pycache__/rec_crnn_backbone.cpython-36.pyc b/ptocr/model/backbone/__pycache__/rec_crnn_backbone.cpython-36.pyc new file mode 100644 index 0000000..89a161d Binary files /dev/null and b/ptocr/model/backbone/__pycache__/rec_crnn_backbone.cpython-36.pyc differ diff --git a/ptocr/model/backbone/__pycache__/rec_mobilev3_bd.cpython-36.pyc b/ptocr/model/backbone/__pycache__/rec_mobilev3_bd.cpython-36.pyc new file mode 100644 index 0000000..cfff19c Binary files /dev/null and b/ptocr/model/backbone/__pycache__/rec_mobilev3_bd.cpython-36.pyc differ diff --git a/ptocr/model/backbone/__pycache__/rec_vgg.cpython-36.pyc b/ptocr/model/backbone/__pycache__/rec_vgg.cpython-36.pyc index abe0e48..b16abd4 100644 Binary files a/ptocr/model/backbone/__pycache__/rec_vgg.cpython-36.pyc and b/ptocr/model/backbone/__pycache__/rec_vgg.cpython-36.pyc differ diff --git a/ptocr/model/backbone/__pycache__/reg_resnet_bd.cpython-36.pyc b/ptocr/model/backbone/__pycache__/reg_resnet_bd.cpython-36.pyc new file mode 100644 index 0000000..3fc663a Binary files /dev/null and b/ptocr/model/backbone/__pycache__/reg_resnet_bd.cpython-36.pyc differ diff --git a/ptocr/model/backbone/det_densenet.py b/ptocr/model/backbone/det_densenet.py deleted file mode 100644 index 4bf2353..0000000 --- a/ptocr/model/backbone/det_densenet.py +++ /dev/null @@ -1,242 +0,0 @@ -#-*- coding:utf-8 _*- -""" -@author:fxw -@file: det_densenet.py -@time: 2020/08/07 -""" -import re -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.model_zoo as model_zoo -from collections import OrderedDict - -__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] - - -model_urls = { - 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', - 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', - 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', - 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', -} - - -class _DenseLayer(nn.Sequential): - def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): - super(_DenseLayer, self).__init__() - self.add_module('norm1', nn.BatchNorm2d(num_input_features)), - self.add_module('relu1', nn.ReLU(inplace=True)), - self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * - growth_rate, kernel_size=1, stride=1, bias=False)), - self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), - self.add_module('relu2', nn.ReLU(inplace=True)), - self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, - kernel_size=3, stride=1, padding=1, bias=False)), - self.drop_rate = drop_rate - - def forward(self, x): - new_features = super(_DenseLayer, self).forward(x) - if self.drop_rate > 0: - new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) - return torch.cat([x, new_features], 1) - - -class _DenseBlock(nn.Sequential): - def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): - super(_DenseBlock, self).__init__() - for i in range(num_layers): - layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, 
drop_rate) - self.add_module('denselayer%d' % (i + 1), layer) - - -class _Transition(nn.Sequential): - def __init__(self, num_input_features, num_output_features): - super(_Transition, self).__init__() - self.add_module('norm', nn.BatchNorm2d(num_input_features)) - self.add_module('relu', nn.ReLU(inplace=True)) - self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, - kernel_size=1, stride=1, bias=False)) - self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) - - -class DenseNet(nn.Module): - r"""Densenet-BC model class, based on - `"Densely Connected Convolutional Networks" `_ - - Args: - growth_rate (int) - how many filters to add each layer (`k` in paper) - block_config (list of 4 ints) - how many layers in each pooling block - num_init_features (int) - the number of filters to learn in the first convolution layer - bn_size (int) - multiplicative factor for number of bottle neck layers - (i.e. bn_size * k features in the bottleneck layer) - drop_rate (float) - dropout rate after each dense layer - num_classes (int) - number of classification classes - """ - - def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), - num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): - - super(DenseNet, self).__init__() - - # First convolution - self.features = nn.Sequential(OrderedDict([ - ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), - ('norm0', nn.BatchNorm2d(num_init_features)), - ('relu0', nn.ReLU(inplace=True)), - ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), - ])) - - # Each denseblock - num_features = num_init_features - for i, num_layers in enumerate(block_config): - block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, - bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) - self.features.add_module('denseblock%d' % (i + 1), block) - num_features = num_features + num_layers * growth_rate - if i != len(block_config) - 1: - trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) - self.features.add_module('transition%d' % (i + 1), trans) - num_features = num_features // 2 - - # Final batch norm - self.features.add_module('norm5', nn.BatchNorm2d(num_features)) - - # Linear layer - self.classifier = nn.Linear(num_features, num_classes) - - # Official init from torch repo. - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight) - elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.Linear): - nn.init.constant_(m.bias, 0) - - # def forward(self, x): - # features = self.features(x) - # out = F.relu(features, inplace=True) - # out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1) - # out = self.classifier(out) - # return out - - def forward(self, x): - out_list = [] - for layer in self.features.children(): - x = layer(x) - if isinstance(layer, _DenseBlock): - out_list.append(x) - p1, p2, p3, p4 = out_list - return p1, p2, p3, p4 - - - -def densenet121(pretrained=False, **kwargs): - r"""Densenet-121 model from - `"Densely Connected Convolutional Networks" `_ - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), - **kwargs) - if pretrained: - # '.'s are no longer allowed in module names, but pervious _DenseLayer - # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 
- # They are also in the checkpoints in model_urls. This pattern is used - # to find such keys. - pattern = re.compile( - r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') - state_dict = model_zoo.load_url(model_urls['densenet121']) - for key in list(state_dict.keys()): - res = pattern.match(key) - if res: - new_key = res.group(1) + res.group(2) - state_dict[new_key] = state_dict[key] - del state_dict[key] - model.load_state_dict(state_dict) - return model - - -def densenet169(pretrained=False, **kwargs): - r"""Densenet-169 model from - `"Densely Connected Convolutional Networks" `_ - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), - **kwargs) - if pretrained: - # '.'s are no longer allowed in module names, but pervious _DenseLayer - # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. - # They are also in the checkpoints in model_urls. This pattern is used - # to find such keys. - pattern = re.compile( - r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') - state_dict = model_zoo.load_url(model_urls['densenet169']) - for key in list(state_dict.keys()): - res = pattern.match(key) - if res: - new_key = res.group(1) + res.group(2) - state_dict[new_key] = state_dict[key] - del state_dict[key] - model.load_state_dict(state_dict) - return model - - -def densenet201(pretrained=False, **kwargs): - r"""Densenet-201 model from - `"Densely Connected Convolutional Networks" `_ - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), - **kwargs) - if pretrained: - # '.'s are no longer allowed in module names, but pervious _DenseLayer - # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. - # They are also in the checkpoints in model_urls. This pattern is used - # to find such keys. - pattern = re.compile( - r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') - state_dict = model_zoo.load_url(model_urls['densenet201']) - for key in list(state_dict.keys()): - res = pattern.match(key) - if res: - new_key = res.group(1) + res.group(2) - state_dict[new_key] = state_dict[key] - del state_dict[key] - model.load_state_dict(state_dict) - return model - - -def densenet161(pretrained=False, **kwargs): - r"""Densenet-161 model from - `"Densely Connected Convolutional Networks" `_ - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), - **kwargs) - if pretrained: - # '.'s are no longer allowed in module names, but pervious _DenseLayer - # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. - # They are also in the checkpoints in model_urls. This pattern is used - # to find such keys. 
- pattern = re.compile( - r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') - state_dict = model_zoo.load_url(model_urls['densenet161']) - for key in list(state_dict.keys()): - res = pattern.match(key) - if res: - new_key = res.group(1) + res.group(2) - state_dict[new_key] = state_dict[key] - del state_dict[key] - model.load_state_dict(state_dict) - return model - diff --git a/ptocr/model/backbone/det_scnet.py b/ptocr/model/backbone/det_scnet.py deleted file mode 100644 index 714b84b..0000000 --- a/ptocr/model/backbone/det_scnet.py +++ /dev/null @@ -1,294 +0,0 @@ -#-*- coding:utf-8 _*- -""" -@author:fxw -@file: det_scnet.py -@time: 2020/08/07 -""" -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.model_zoo as model_zoo - -__all__ = ['SCNet', 'scnet50', 'scnet50_v1d'] - -model_urls = { - 'scnet50': 'https://backseason.oss-cn-beijing.aliyuncs.com/scnet/scnet50-dc6a7e87.pth', - 'scnet50_v1d': 'https://backseason.oss-cn-beijing.aliyuncs.com/scnet/scnet50_v1d-4109d1e1.pth', -} - -class SCConv(nn.Module): - def __init__(self, inplanes, planes, stride, padding, dilation, groups, pooling_r, norm_layer): - super(SCConv, self).__init__() - self.k2 = nn.Sequential( - nn.AvgPool2d(kernel_size=pooling_r, stride=pooling_r,padding=1), - nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, - padding=padding, dilation=dilation, - groups=groups, bias=False), - norm_layer(planes), - ) - self.k3 = nn.Sequential( - nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, - padding=padding, dilation=dilation, - groups=groups, bias=False), - norm_layer(planes), - ) - self.k4 = nn.Sequential( - nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, - padding=padding, dilation=dilation, - groups=groups, bias=False), - norm_layer(planes), - ) - - def forward(self, x): - identity = x - - out = torch.sigmoid(torch.add(identity, F.interpolate(self.k2(x), identity.size()[2:]))) # sigmoid(identity + k2) - out = torch.mul(self.k3(x), out) # k3 * sigmoid(identity + k2) - out = self.k4(out) # k4 - - return out - -class SCBottleneck(nn.Module): - """SCNet SCBottleneck - """ - expansion = 4 - pooling_r = 4 # down-sampling rate of the avg pooling layer in the K3 path of SC-Conv. 
- - def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, bottleneck_width=32, - avd=False, dilation=1, is_first=False, - norm_layer=None): - super(SCBottleneck, self).__init__() - group_width = int(planes * (bottleneck_width / 64.)) * cardinality - self.conv1_a = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False) - self.bn1_a = norm_layer(group_width) - self.conv1_b = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False) - self.bn1_b = norm_layer(group_width) - self.avd = avd and (stride > 1 or is_first) - - if self.avd: - self.avd_layer = nn.AvgPool2d(3, stride, padding=1) - stride = 1 - - self.k1 = nn.Sequential( - nn.Conv2d( - group_width, group_width, kernel_size=3, stride=stride, - padding=dilation, dilation=dilation, - groups=cardinality, bias=False), - norm_layer(group_width), - ) - - self.scconv = SCConv( - group_width, group_width, stride=stride, - padding=dilation, dilation=dilation, - groups=cardinality, pooling_r=self.pooling_r, norm_layer=norm_layer) - - self.conv3 = nn.Conv2d( - group_width * 2, planes * 4, kernel_size=1, bias=False) - self.bn3 = norm_layer(planes*4) - - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.dilation = dilation - self.stride = stride - - def forward(self, x): - residual = x - - out_a= self.conv1_a(x) - out_a = self.bn1_a(out_a) - out_b = self.conv1_b(x) - out_b = self.bn1_b(out_b) - out_a = self.relu(out_a) - out_b = self.relu(out_b) - - out_a = self.k1(out_a) - out_b = self.scconv(out_b) - out_a = self.relu(out_a) - out_b = self.relu(out_b) - - if self.avd: - out_a = self.avd_layer(out_a) - out_b = self.avd_layer(out_b) - - out = self.conv3(torch.cat([out_a, out_b], dim=1)) - out = self.bn3(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - - return out - -class SCNet(nn.Module): - """ SCNet Variants Definations - Parameters - ---------- - block : Block - Class for the residual block. - layers : list of int - Numbers of layers in each block. - classes : int, default 1000 - Number of classification classes. - dilated : bool, default False - Applying dilation strategy to pretrained SCNet yielding a stride-8 model. - deep_stem : bool, default False - Replace 7x7 conv in input stem with 3 3x3 conv. - avg_down : bool, default False - Use AvgPool instead of stride conv when - downsampling in the bottleneck. - norm_layer : object - Normalization layer used (default: :class:`torch.nn.BatchNorm2d`). - Reference: - - He, Kaiming, et al. "Deep residual learning for image recognition." - Proceedings of the IEEE conference on computer vision and pattern recognition. 2016. - - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions." 
- """ - def __init__(self, block, layers, groups=1, bottleneck_width=32, - num_classes=1000, dilated=False, dilation=1, - deep_stem=False, stem_width=64, avg_down=False, - avd=False, norm_layer=nn.BatchNorm2d): - self.cardinality = groups - self.bottleneck_width = bottleneck_width - # ResNet-D params - self.inplanes = stem_width*2 if deep_stem else 64 - self.avg_down = avg_down - self.avd = avd - - super(SCNet, self).__init__() - conv_layer = nn.Conv2d - if deep_stem: - self.conv1 = nn.Sequential( - conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False), - norm_layer(stem_width), - nn.ReLU(inplace=True), - conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False), - norm_layer(stem_width), - nn.ReLU(inplace=True), - conv_layer(stem_width, stem_width*2, kernel_size=3, stride=1, padding=1, bias=False), - ) - else: - self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3, - bias=False) - self.bn1 = norm_layer(self.inplanes) - self.relu = nn.ReLU(inplace=True) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, is_first=False) - self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) - if dilated or dilation == 4: - self.layer3 = self._make_layer(block, 256, layers[2], stride=1, - dilation=2, norm_layer=norm_layer) - self.layer4 = self._make_layer(block, 512, layers[3], stride=1, - dilation=4, norm_layer=norm_layer) - elif dilation==2: - self.layer3 = self._make_layer(block, 256, layers[2], stride=2, - dilation=1, norm_layer=norm_layer) - self.layer4 = self._make_layer(block, 512, layers[3], stride=1, - dilation=2, norm_layer=norm_layer) - else: - self.layer3 = self._make_layer(block, 256, layers[2], stride=2, - norm_layer=norm_layer) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2, - norm_layer=norm_layer) - self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = nn.Linear(512 * block.expansion, num_classes) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, norm_layer): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None, - is_first=True): - downsample = None - if stride != 1 or self.inplanes != planes * block.expansion: - down_layers = [] - if self.avg_down: - if dilation == 1: - down_layers.append(nn.AvgPool2d(kernel_size=stride, stride=stride, - ceil_mode=True, count_include_pad=False)) - else: - down_layers.append(nn.AvgPool2d(kernel_size=1, stride=1, - ceil_mode=True, count_include_pad=False)) - down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion, - kernel_size=1, stride=1, bias=False)) - else: - down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion, - kernel_size=1, stride=stride, bias=False)) - down_layers.append(norm_layer(planes * block.expansion)) - downsample = nn.Sequential(*down_layers) - - layers = [] - if dilation == 1 or dilation == 2: - layers.append(block(self.inplanes, planes, stride, downsample=downsample, - cardinality=self.cardinality, - bottleneck_width=self.bottleneck_width, - avd=self.avd, dilation=1, is_first=is_first, - norm_layer=norm_layer)) - elif dilation == 4: - layers.append(block(self.inplanes, planes, stride, downsample=downsample, - cardinality=self.cardinality, - bottleneck_width=self.bottleneck_width, - avd=self.avd, dilation=2, 
is_first=is_first, - norm_layer=norm_layer)) - else: - raise RuntimeError("=> unknown dilation size: {}".format(dilation)) - - self.inplanes = planes * block.expansion - for i in range(1, blocks): - layers.append(block(self.inplanes, planes, - cardinality=self.cardinality, - bottleneck_width=self.bottleneck_width, - avd=self.avd, dilation=dilation, - norm_layer=norm_layer)) - - return nn.Sequential(*layers) - - def forward(self, x): - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = self.maxpool(x) - - x2 = self.layer1(x) - x3 = self.layer2(x2) - x4 = self.layer3(x3) - x5 = self.layer4(x4) - - return x2, x3, x4, x5 - - - -def scnet50(pretrained=False, **kwargs): - """Constructs a SCNet-50 model. - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = SCNet(SCBottleneck, [3, 4, 6, 3], - deep_stem=False, stem_width=32, avg_down=False, - avd=False, **kwargs) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['scnet50']),strict=False) - return model - -def scnet50_v1d(pretrained=False, **kwargs): - """Constructs a SCNet-50_v1d model described in - `Bag of Tricks `_. - `ResNeSt: Split-Attention Networks `_. - Compared with default SCNet(SCNetv1b), SCNetv1d replaces the 7x7 conv - in the input stem with three 3x3 convs. And in the downsampling block, - a 3x3 avg_pool with stride 2 is added before conv, whose stride is - changed to 1. - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = SCNet(SCBottleneck, [3, 4, 6, 3], - deep_stem=True, stem_width=32, avg_down=True, - avd=True, **kwargs) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['scnet50_v1d']),strict=False) - return model diff --git a/ptocr/model/backbone/rec_crnn_backbone.py b/ptocr/model/backbone/rec_crnn_backbone.py deleted file mode 100644 index c401800..0000000 --- a/ptocr/model/backbone/rec_crnn_backbone.py +++ /dev/null @@ -1,79 +0,0 @@ -#-*- coding:utf-8 _*- -""" -@author:fxw -@file: crnn_backbone.py -@time: 2020/10/12 -""" -import torch -import torch.nn as nn - -class conv_bn_relu(nn.Module): - def __init__(self,in_c,out_c,k_s,s,p,with_bn=True): - super(conv_bn_relu, self).__init__() - self.conv = nn.Conv2d(in_c,out_c,k_s,s,p) - self.bn = nn.BatchNorm2d(out_c) - self.relu = nn.ReLU() - self.with_bn = with_bn - def forward(self, x): - x = self.conv(x) - if self.with_bn: - x = self.bn(x) - x = self.relu(x) - return x - -class crnn_backbone(nn.Module): - def __init__(self,is_gray): - super(crnn_backbone, self).__init__() - if(is_gray): - nc = 1 - else: - nc = 3 - base_channel = 64 - self.cnn = nn.Sequential( - conv_bn_relu(nc,base_channel,3,1,1), - nn.MaxPool2d(2,2), - conv_bn_relu(base_channel,base_channel*2,3,1,1), - nn.MaxPool2d(2, 2), - conv_bn_relu(base_channel*2,base_channel*4,3,1,1), - conv_bn_relu(base_channel*4,base_channel*4,3,1,1), - nn.MaxPool2d((2,1),(2,1)), - conv_bn_relu(base_channel*4,base_channel*8,3,1,1,with_bn=True), - conv_bn_relu(base_channel*8,base_channel*8,3,1,1,with_bn=True), - nn.MaxPool2d((2,1), (2,1)), - conv_bn_relu(base_channel*8,base_channel*8,(2,1),1,0) - # conv_bn_relu(base_channel*8,base_channel*8,2,1,0) - ) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight.data) - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1.) 
- m.bias.data.fill_(1e-4) - - def forward(self, x): - x = self.cnn(x) - return x - -def rec_crnn_backbone(pretrained=False, is_gray=False,**kwargs): - """VGG 19-layer model (configuration 'E') with batch normalization - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - if pretrained: - kwargs['init_weights'] = False - model = crnn_backbone(is_gray) - if pretrained: - pretrained_model = torch.load('./pre_model/crnn_backbone.pth') - state = model.state_dict() - for key in state.keys(): - if key in pretrained_model.keys(): - if (key=='features.0.weight' and is_gray): - state[key] = torch.mean(pretrained_model[key],1).unsqueeze(1) - else: - state[key] = pretrained_model[key] - model.load_state_dict(state) - return model - - diff --git a/ptocr/model/backbone/rec_mobilev3_bd.py b/ptocr/model/backbone/rec_mobilev3_bd.py new file mode 100644 index 0000000..417abba --- /dev/null +++ b/ptocr/model/backbone/rec_mobilev3_bd.py @@ -0,0 +1,279 @@ +import torch.nn.functional as F +import torch.nn as nn + + +class hswish(nn.Module): + def forward(self, x): + out = x * F.relu6(x + 3, inplace=True) / 6 + return out + +class hsigmoid(nn.Module): + def forward(self, x): + out = F.relu6(x + 3, inplace=True) / 6 + return out + +def make_divisible(v, divisor=8, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: + new_v += divisor + return new_v + +class ConvBNLayer(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + groups=1, + if_act=True, + act=None): + super(ConvBNLayer, self).__init__() + self.if_act = if_act + self.act = act + if self.act == "hardswish": + self.hardswish = hswish() + + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups) + + self.bn = nn.BatchNorm2d(out_channels) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + if self.if_act: + if self.act == "relu": + x = F.relu(x) + elif self.act == "hardswish": + x = self.hardswish(x) + else: + print("The activation function({}) is selected incorrectly.". 
+ format(self.act)) + exit() + return x + +class ResidualUnit(nn.Module): + def __init__(self, + in_channels, + mid_channels, + out_channels, + kernel_size, + stride, + use_se, + act=None): + super(ResidualUnit, self).__init__() + self.if_shortcut = stride == 1 and in_channels == out_channels + self.if_se = use_se + + self.expand_conv = ConvBNLayer( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + stride=1, + padding=0, + if_act=True, + act=act) + self.bottleneck_conv = ConvBNLayer( + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + padding=int((kernel_size - 1) // 2), + groups=mid_channels, + if_act=True, + act=act) + if self.if_se: + self.mid_se = SEModule(mid_channels) + self.linear_conv = ConvBNLayer( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + if_act=False, + act=None) + + def forward(self, inputs): + x = self.expand_conv(inputs) + x = self.bottleneck_conv(x) + if self.if_se: + x = self.mid_se(x) + x = self.linear_conv(x) + if self.if_shortcut: + x = inputs+x + return x + + +class SEModule(nn.Module): + def __init__(self, in_channels, reduction=4): + super(SEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels // reduction, + kernel_size=1, + stride=1, + padding=0) + self.conv2 = nn.Conv2d( + in_channels=in_channels // reduction, + out_channels=in_channels, + kernel_size=1, + stride=1, + padding=0) + self.hardsigmoid = hsigmoid() + + def forward(self, inputs): + outputs = self.avg_pool(inputs) + outputs = self.conv1(outputs) + outputs = F.relu(outputs) + outputs = self.conv2(outputs) + outputs = self.hardsigmoid(outputs) + return inputs * outputs + +class MobileNetV3(nn.Module): + def __init__(self, + in_channels=3, + model_name='small', + scale=0.5, + large_stride=None, + small_stride=None, + **kwargs): + super(MobileNetV3, self).__init__() + if small_stride is None: + small_stride = [1, 2, 2, 2] + if large_stride is None: + large_stride = [1, 2, 2, 2] + + assert isinstance(large_stride, list), "large_stride type must " \ + "be list but got {}".format(type(large_stride)) + assert isinstance(small_stride, list), "small_stride type must " \ + "be list but got {}".format(type(small_stride)) + assert len(large_stride) == 4, "large_stride length must be " \ + "4 but got {}".format(len(large_stride)) + assert len(small_stride) == 4, "small_stride length must be " \ + "4 but got {}".format(len(small_stride)) + + if model_name == "large": + cfg = [ + # k, exp, c, se, nl, s, + [3, 16, 16, False, 'relu', large_stride[0]], + [3, 64, 24, False, 'relu', (large_stride[1], 1)], + [3, 72, 24, False, 'relu', 1], + [5, 72, 40, True, 'relu', (large_stride[2], 1)], + [5, 120, 40, True, 'relu', 1], + [5, 120, 40, True, 'relu', 1], + [3, 240, 80, False, 'hardswish', 1], + [3, 200, 80, False, 'hardswish', 1], + [3, 184, 80, False, 'hardswish', 1], + [3, 184, 80, False, 'hardswish', 1], + [3, 480, 112, True, 'hardswish', 1], + [3, 672, 112, True, 'hardswish', 1], + [5, 672, 160, True, 'hardswish', (large_stride[3], 1)], + [5, 960, 160, True, 'hardswish', 1], + [5, 960, 160, True, 'hardswish', 1], + ] + cls_ch_squeeze = 960 + elif model_name == "small": + cfg = [ + # k, exp, c, se, nl, s, + [3, 16, 16, True, 'relu', (small_stride[0], 1)], + [3, 72, 24, False, 'relu', (small_stride[1], 1)], + [3, 88, 24, False, 'relu', 1], + [5, 96, 40, True, 'hardswish', (small_stride[2], 1)], + [5, 
240, 40, True, 'hardswish', 1], + [5, 240, 40, True, 'hardswish', 1], + [5, 120, 48, True, 'hardswish', 1], + [5, 144, 48, True, 'hardswish', 1], + [5, 288, 96, True, 'hardswish', (small_stride[3], 1)], + [5, 576, 96, True, 'hardswish', 1], + [5, 576, 96, True, 'hardswish', 1], + ] + cls_ch_squeeze = 576 + else: + raise NotImplementedError("mode[" + model_name + + "_model] is not implemented!") + + supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25] + assert scale in supported_scale, \ + "supported scales are {} but input scale is {}".format(supported_scale, scale) + + inplanes = 16 + # conv1 + self.conv1 = ConvBNLayer( + in_channels=in_channels, + out_channels=make_divisible(inplanes * scale), + kernel_size=3, + stride=2, + padding=1, + groups=1, + if_act=True, + act='hardswish') + i = 0 + block_list = [] + inplanes = make_divisible(inplanes * scale) + for (k, exp, c, se, nl, s) in cfg: + block_list.append( + ResidualUnit( + in_channels=inplanes, + mid_channels=make_divisible(scale * exp), + out_channels=make_divisible(scale * c), + kernel_size=k, + stride=s, + use_se=se, + act=nl)) + inplanes = make_divisible(scale * c) + i += 1 + self.blocks = nn.Sequential(*block_list) + + self.conv2 = ConvBNLayer( + in_channels=inplanes, + out_channels=make_divisible(scale * cls_ch_squeeze), + kernel_size=1, + stride=1, + padding=0, + groups=1, + if_act=True, + act='hardswish') + + self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + self.out_channels = make_divisible(scale * cls_ch_squeeze) + + def forward(self, x): + x = self.conv1(x) + x = self.blocks(x) + x = self.conv2(x) + x = self.pool(x) + return x + + +def mobilenet_v3_small(pretrained, is_gray=False,**kwargs): + if is_gray: + in_channels = 1 + else: + in_channels = 3 + model = MobileNetV3( in_channels=in_channels, + model_name='small', + scale = 1) + if pretrained: + pass + return model + +def mobilenet_v3_large(pretrained,is_gray=False,**kwargs): + if is_gray: + in_channels = 1 + else: + in_channels = 3 + model = MobileNetV3( in_channels=in_channels, + model_name='large', + scale=1) + if pretrained: + pass + return model \ No newline at end of file diff --git a/ptocr/model/backbone/rec_vgg.py b/ptocr/model/backbone/rec_vgg.py deleted file mode 100644 index 40d1c95..0000000 --- a/ptocr/model/backbone/rec_vgg.py +++ /dev/null @@ -1,258 +0,0 @@ -#-*- coding:utf-8 _*- -""" -@author:fxw -@file: vgg.py -@time: 2020/07/24 -""" -import torch -import torch.nn as nn -import torch.utils.model_zoo as model_zoo - - -__all__ = [ - 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', - 'vgg19_bn', 'vgg19', -] - - -model_urls = { - 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', - 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', - 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', - 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', - 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', - 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', - 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', - 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', -} - - -class VGG(nn.Module): - - def __init__(self, features, num_classes=1000, init_weights=True): - super(VGG, self).__init__() - self.features = features - if init_weights: - self._initialize_weights() - - def forward(self, x): - x = self.features(x) - return x - - def _initialize_weights(self): - for m in self.modules(): - if isinstance(m, 
nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.Linear): - nn.init.normal_(m.weight, 0, 0.01) - nn.init.constant_(m.bias, 0) - - -def make_layers(cfg, is_gray=False,batch_norm=False): - layers = [] - in_channels = 3 - if is_gray: - in_channels = 1 - for v in cfg: - if v == 'M': - layers += [nn.MaxPool2d(kernel_size=2, stride=2)] - elif v=='N': - layers += [nn.MaxPool2d(kernel_size=2, stride=(2,1))] - else: - conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) - if batch_norm: - layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] - else: - layers += [conv2d, nn.ReLU(inplace=True)] - in_channels = v - return nn.Sequential(*layers) - - -cfg = { - 'A': [64, 'M', 128, 'M', 256, 256, 'N', 512, 512, 'N', 512, 512, 'N'], - 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'N', 512, 512, 'N', 512, 512, 'N'], - 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'N', 512, 512, 512, 'N', 512, 512, 512, 'N'], - 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'N', 512, 512, 512, 512, 'N'], -} - - -def vgg11(pretrained=False,is_gray=False, **kwargs): - """VGG 11-layer model (configuration "A") - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - if pretrained: - kwargs['init_weights'] = False - model = VGG(make_layers(cfg['A'],is_gray=is_gray), **kwargs) - if pretrained: - pretrained_model = model_zoo.load_url(model_urls['vgg11']) - state = model.state_dict() - for key in state.keys(): - if key in pretrained_model.keys(): - if (key=='features.0.weight' and is_gray): - state[key] = torch.mean(pretrained_model[key],1).unsqueeze(1) - else: - state[key] = pretrained_model[key] - model.load_state_dict(state) - return model - - -def vgg11_bn(pretrained=False,is_gray=False, **kwargs): - """VGG 11-layer model (configuration "A") with batch normalization - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - if pretrained: - kwargs['init_weights'] = False - model = VGG(make_layers(cfg['A'], is_gray=is_gray,batch_norm=True), **kwargs) - if pretrained: - pretrained_model = model_zoo.load_url(model_urls['vgg11_bn']) - state = model.state_dict() - for key in state.keys(): - if key in pretrained_model.keys(): - if (key=='features.0.weight' and is_gray): - state[key] = torch.mean(pretrained_model[key],1).unsqueeze(1) - else: - state[key] = pretrained_model[key] - model.load_state_dict(state) - return model - - -def vgg13(pretrained=False,is_gray=False, **kwargs): - """VGG 13-layer model (configuration "B") - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - if pretrained: - kwargs['init_weights'] = False - model = VGG(make_layers(cfg['B'],is_gray=is_gray), **kwargs) - if pretrained: - pretrained_model = model_zoo.load_url(model_urls['vgg13']) - state = model.state_dict() - for key in state.keys(): - if key in pretrained_model.keys(): - if (key=='features.0.weight' and is_gray): - state[key] = torch.mean(pretrained_model[key],1).unsqueeze(1) - else: - state[key] = pretrained_model[key] - model.load_state_dict(state) - return model - - -def vgg13_bn(pretrained=False, is_gray=False,**kwargs): - """VGG 13-layer model (configuration "B") with batch normalization - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - if pretrained: - 
kwargs['init_weights'] = False - model = VGG(make_layers(cfg['B'],is_gray=is_gray, batch_norm=True), **kwargs) - if pretrained: - pretrained_model = model_zoo.load_url(model_urls['vgg13_bn']) - state = model.state_dict() - for key in state.keys(): - if key in pretrained_model.keys(): - if (key=='features.0.weight' and is_gray): - state[key] = torch.mean(pretrained_model[key],1).unsqueeze(1) - else: - state[key] = pretrained_model[key] - model.load_state_dict(state) - return model - - -def vgg16(pretrained=False,is_gray=False, **kwargs): - """VGG 16-layer model (configuration "D") - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - if pretrained: - kwargs['init_weights'] = False - model = VGG(make_layers(cfg['D'],is_gray=is_gray), **kwargs) - if pretrained: - pretrained_model = model_zoo.load_url(model_urls['vgg16']) - state = model.state_dict() - for key in state.keys(): - if key in pretrained_model.keys(): - if (key=='features.0.weight' and is_gray): - state[key] = torch.mean(pretrained_model[key],1).unsqueeze(1) - else: - state[key] = pretrained_model[key] - model.load_state_dict(state) - return model - - -def vgg16_bn(pretrained=False,is_gray=False, **kwargs): - """VGG 16-layer model (configuration "D") with batch normalization - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - if pretrained: - kwargs['init_weights'] = False - model = VGG(make_layers(cfg['D'],is_gray=is_gray, batch_norm=True), **kwargs) - if pretrained: - pretrained_model = model_zoo.load_url(model_urls['vgg16_bn']) - state = model.state_dict() - for key in state.keys(): - if key in pretrained_model.keys(): - if (key=='features.0.weight' and is_gray): - state[key] = torch.mean(pretrained_model[key],1).unsqueeze(1) - else: - state[key] = pretrained_model[key] - model.load_state_dict(state) - return model - - -def vgg19(pretrained=False, is_gray=False,**kwargs): - """VGG 19-layer model (configuration "E") - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - if pretrained: - kwargs['init_weights'] = False - model = VGG(make_layers(cfg['E'],is_gray=is_gray), **kwargs) - if pretrained: - pretrained_model = model_zoo.load_url(model_urls['vgg19']) - state = model.state_dict() - for key in state.keys(): - if key in pretrained_model.keys(): - if (key=='features.0.weight' and is_gray): - state[key] = torch.mean(pretrained_model[key],1).unsqueeze(1) - else: - state[key] = pretrained_model[key] - model.load_state_dict(state) - return model - - -def vgg19_bn(pretrained=False, is_gray=False,**kwargs): - """VGG 19-layer model (configuration 'E') with batch normalization - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - if pretrained: - kwargs['init_weights'] = False - model = VGG(make_layers(cfg['E'],is_gray=is_gray, batch_norm=True), **kwargs) - if pretrained: - pretrained_model = model_zoo.load_url(model_urls['vgg19_bn']) - state = model.state_dict() - for key in state.keys(): - if key in pretrained_model.keys(): - if (key=='features.0.weight' and is_gray): - state[key] = torch.mean(pretrained_model[key],1).unsqueeze(1) - else: - state[key] = pretrained_model[key] - model.load_state_dict(state) - return model - diff --git a/ptocr/model/backbone/reg_resnet.py b/ptocr/model/backbone/reg_resnet.py deleted file mode 100644 index 832aa58..0000000 --- a/ptocr/model/backbone/reg_resnet.py +++ /dev/null @@ -1,368 +0,0 @@ -#-*- coding:utf-8 _*- -""" -@author:fxw -@file: det_resnet.py.py -@time: 
2020/08/07 -""" -import torch -import torch.nn as nn -import math -import torch.utils.model_zoo as model_zoo - -BatchNorm2d = nn.BatchNorm2d - -__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', - 'resnet152'] - -model_urls = { - 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', - 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', - 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', - 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', - 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', -} - - -def load_pre_model(model,pre_model_path): - pre_dict = torch.load(pre_model_path) - model_pre_dict = {} - for key in model.state_dict().keys(): - if('model.module.backbone.'+key in pre_dict.keys()): - model_pre_dict[key] = pre_dict['model.module.backbone.'+key] - else: - model_pre_dict[key] = model.state_dict()[key] - model.load_state_dict(model_pre_dict) - return model - - -def constant_init(module, constant, bias=0): - nn.init.constant_(module.weight, constant) - if hasattr(module, 'bias'): - nn.init.constant_(module.bias, bias) - - -def conv3x3(in_planes, out_planes, stride=1): - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=1, bias=False) - - -class BasicBlock(nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None): - super(BasicBlock, self).__init__() - self.with_dcn = dcn is not None - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = BatchNorm2d(planes) - self.relu = nn.ReLU(inplace=True) - self.with_modulated_dcn = False - if self.with_dcn: - fallback_on_stride = dcn.get('fallback_on_stride', False) - self.with_modulated_dcn = dcn.get('modulated', False) - # self.conv2 = conv3x3(planes, planes) - if not self.with_dcn or fallback_on_stride: - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, - padding=1, bias=False) - else: - deformable_groups = dcn.get('deformable_groups', 1) - if not self.with_modulated_dcn: - from models.dcn import DeformConv - conv_op = DeformConv - offset_channels = 18 - else: - from models.dcn import ModulatedDeformConv - conv_op = ModulatedDeformConv - offset_channels = 27 - self.conv2_offset = nn.Conv2d( - planes, - deformable_groups * offset_channels, - kernel_size=3, - padding=1) - self.conv2 = conv_op( - planes, - planes, - kernel_size=3, - padding=1, - deformable_groups=deformable_groups, - bias=False) - self.bn2 = BatchNorm2d(planes) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - # out = self.conv2(out) - if not self.with_dcn: - out = self.conv2(out) - elif self.with_modulated_dcn: - offset_mask = self.conv2_offset(out) - offset = offset_mask[:, :18, :, :] - mask = offset_mask[:, -9:, :, :].sigmoid() - out = self.conv2(out, offset, mask) - else: - offset = self.conv2_offset(out) - out = self.conv2(out, offset) - out = self.bn2(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - - return out - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None): - super(Bottleneck, self).__init__() - self.with_dcn = dcn is not None - self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) - self.bn1 = BatchNorm2d(planes) - fallback_on_stride = False 
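# Illustrative aside (not part of the original patch): the deformable-conv
# branches in this deleted backbone size `conv2_offset` from the kernel
# geometry. For a 3x3 kernel, each of the 9 sampling points needs an (x, y)
# offset, and modulated DCN adds one scalar mask per point, giving the 18 and
# 27 channels used above. `dcn_offset_channels` is a hypothetical helper
# sketching that arithmetic:
def dcn_offset_channels(kernel_size=3, modulated=False, deformable_groups=1):
    points = kernel_size * kernel_size      # sampling points per kernel
    per_point = 3 if modulated else 2       # (dx, dy), plus a mask if modulated
    return deformable_groups * points * per_point

assert dcn_offset_channels(modulated=False) == 18
assert dcn_offset_channels(modulated=True) == 27
# end aside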
- self.with_modulated_dcn = False - if self.with_dcn: - fallback_on_stride = dcn.get('fallback_on_stride', False) - self.with_modulated_dcn = dcn.get('modulated', False) - if not self.with_dcn or fallback_on_stride: - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, - stride=stride, padding=1, bias=False) - else: - deformable_groups = dcn.get('deformable_groups', 1) - if not self.with_modulated_dcn: - from models.dcn import DeformConv - conv_op = DeformConv - offset_channels = 18 - else: - from models.dcn import ModulatedDeformConv - conv_op = ModulatedDeformConv - offset_channels = 27 - self.conv2_offset = nn.Conv2d( - planes, deformable_groups * offset_channels, - kernel_size=3, - padding=1) - self.conv2 = conv_op( - planes, planes, kernel_size=3, padding=1, stride=stride, - deformable_groups=deformable_groups, bias=False) - self.bn2 = BatchNorm2d(planes) - self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) - self.bn3 = BatchNorm2d(planes * 4) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - self.dcn = dcn - self.with_dcn = dcn is not None - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - # out = self.conv2(out) - if not self.with_dcn: - out = self.conv2(out) - elif self.with_modulated_dcn: - offset_mask = self.conv2_offset(out) - offset = offset_mask[:, :18, :, :] - mask = offset_mask[:, -9:, :, :].sigmoid() - out = self.conv2(out, offset, mask) - else: - offset = self.conv2_offset(out) - out = self.conv2(out, offset) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - - return out - - -class ResNet(nn.Module): - def __init__(self, block, layers, num_classes=1000, - dcn=None, stage_with_dcn=(False, False, False, False)): - self.dcn = dcn - self.stage_with_dcn = stage_with_dcn - self.inplanes = 64 - super(ResNet, self).__init__() - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, - bias=False) - self.bn1 = BatchNorm2d(64) - self.relu = nn.ReLU(inplace=True) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=(2,1), padding=1) - self.layer1 = self._make_layer(block, 64, layers[0]) - self.layer2 = self._make_layer( - block, 128, layers[1], stride=2, dcn=dcn) - self.layer3 = self._make_layer( - block, 256, layers[2], stride=(2,1), dcn=dcn) - self.layer4 = self._make_layer( - block, 512, layers[3], stride=(2,1), dcn=dcn) - # self.avgpool = nn.AvgPool2d(7, stride=1) - # self.fc = nn.Linear(512 * block.expansion, num_classes) - - # self.smooth = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=1) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. 
/ n)) - elif isinstance(m, BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - if self.dcn is not None: - for m in self.modules(): - if isinstance(m, Bottleneck) or isinstance(m, BasicBlock): - if hasattr(m, 'conv2_offset'): - constant_init(m.conv2_offset, 0) - - def _make_layer(self, block, planes, blocks, stride=1, dcn=None): - downsample = None - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - nn.Conv2d(self.inplanes, planes * block.expansion, - kernel_size=1, stride=stride, bias=False), - BatchNorm2d(planes * block.expansion), - ) - - layers = [] - layers.append(block(self.inplanes, planes, - stride, downsample, dcn=dcn)) - self.inplanes = planes * block.expansion - for i in range(1, blocks): - layers.append(block(self.inplanes, planes, dcn=dcn)) - - return nn.Sequential(*layers) - - def forward(self, x): - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = self.maxpool(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - - return x - - -def resnet18(pretrained=True, load_url=False,**kwargs): - """Constructs a ResNet-18 model. - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) - if pretrained: - if load_url: - model.load_state_dict(model_zoo.load_url( - model_urls['resnet50']), strict=False) - else: - model = load_pre_model(model,'./pre_model/pre-trained-model-synthtext-resnet18.pth') - return model - - -def deformable_resnet18(pretrained=True,load_url=False, **kwargs): - """Constructs a ResNet-18 model. - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = ResNet(BasicBlock, [2, 2, 2, 2], - dcn=dict(modulated=True, - deformable_groups=1, - fallback_on_stride=False), - stage_with_dcn=[False, True, True, True], **kwargs) - if pretrained: - if load_url: - model.load_state_dict(model_zoo.load_url( - model_urls['resnet50']), strict=False) - else: - model = load_pre_model(model,'./pre_model/pre-trained-model-synthtext-resnet18.pth') - return model - - -def resnet34(pretrained=True, **kwargs): - """Constructs a ResNet-34 model. - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) - if pretrained: - model.load_state_dict(model_zoo.load_url( - model_urls['resnet34']), strict=False) - return model - - -def resnet50(pretrained=True,load_url=False,**kwargs): - """Constructs a ResNet-50 model. - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) - if pretrained: - if load_url: - model.load_state_dict(model_zoo.load_url( - model_urls['resnet50']), strict=False) - else: - model = load_pre_model(model,'./pre_model/pre-trained-model-synthtext-resnet50.pth') - return model - - -def deformable_resnet50(pretrained=True,load_url=False, **kwargs): - """Constructs a ResNet-50 model with deformable conv. 
- Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = ResNet(Bottleneck, [3, 4, 6, 3], - dcn=dict(modulated=True, - deformable_groups=1, - fallback_on_stride=False), - stage_with_dcn=[False, True, True, True], - **kwargs) - if pretrained: - if load_url: - model.load_state_dict(model_zoo.load_url( - model_urls['resnet50']), strict=False) - else: - model = load_pre_model(model,'./pre_model/pre-trained-model-synthtext-resnet50.pth') - return model - - -def resnet101(pretrained=True, **kwargs): - """Constructs a ResNet-101 model. - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) - if pretrained: - model.load_state_dict(model_zoo.load_url( - model_urls['resnet101']), strict=False) - return model - - -def resnet152(pretrained=True, **kwargs): - """Constructs a ResNet-152 model. - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) - if pretrained: - model.load_state_dict(model_zoo.load_url( - model_urls['resnet152']), strict=False) - return model diff --git a/ptocr/model/backbone/reg_resnet_bd.py b/ptocr/model/backbone/reg_resnet_bd.py new file mode 100644 index 0000000..0840aa2 --- /dev/null +++ b/ptocr/model/backbone/reg_resnet_bd.py @@ -0,0 +1,267 @@ +import torch.nn.functional as F +import torch.nn as nn + +class ConvBNLayer(nn.Module): + def __init__(self,in_channels, + out_channels, + kernel_size, + stride=1, + groups=1, + is_relu=False, + is_vd_mode=False): + super(ConvBNLayer,self).__init__() + + self.is_vd_mode = is_vd_mode + self.is_relu = is_relu + + if is_vd_mode: + self._pool2d_avg = nn.AvgPool2d(kernel_size=stride, stride=stride, padding=0, ceil_mode=True) + + self.conv = nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1 if is_vd_mode else stride, + padding=(kernel_size - 1) // 2, + groups=groups) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU() + + def forward(self,x): + if self.is_vd_mode: + x = self._pool2d_avg(x) + x = self.bn(self.conv(x)) + if self.is_relu: + x = self.relu(x) + return x + + +class BottleneckBlock(nn.Module): + def __init__(self, + in_channels, + out_channels, + stride, + shortcut=True, + if_first=False): + super(BottleneckBlock, self).__init__() + + self.conv0 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + is_relu=True) + + self.conv1 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + is_relu=True) + + self.conv2 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels * 4, + kernel_size=1) + + if not shortcut: + self.short = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels * 4, + kernel_size=1, + stride=stride, + is_vd_mode=not if_first and stride[0] != 1) + + self.shortcut = shortcut + + def forward(self,x): + y = self.conv0(x) + y = self.conv2(self.conv1(y)) + if self.shortcut: + short = x + else: + short = self.short(x) + y = y+short + y = F.relu(y) + return y + + +class BasicBlock(nn.Module): + def __init__(self, + in_channels, + out_channels, + stride, + shortcut=True, + if_first=False): + super(BasicBlock, self).__init__() + self.stride = stride + self.conv0 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + is_relu=True) + self.conv1 = ConvBNLayer( + in_channels=out_channels, + 
out_channels=out_channels, + kernel_size=3) + + if not shortcut: + self.short = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=stride, + is_vd_mode=not if_first and stride[0] != 1) + + self.shortcut = shortcut + + def forward(self, x): + y = self.conv0(x) + y = self.conv1(y) + + if self.shortcut: + short = x + else: + short = self.short(x) + y = y+short + y = F.relu(y) + return y + +class ResNet(nn.Module): + def __init__(self, in_channels=3, layers=50, **kwargs): + super(ResNet, self).__init__() + + self.layers = layers + supported_layers = [18, 34, 50, 101, 152, 200] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format( + supported_layers, layers) + + if layers == 18: + depth = [2, 2, 2, 2] + elif layers == 34 or layers == 50: + depth = [3, 4, 6, 3] + elif layers == 101: + depth = [3, 4, 23, 3] + elif layers == 152: + depth = [3, 8, 36, 3] + elif layers == 200: + depth = [3, 12, 48, 3] + num_channels = [64, 256, 512, + 1024] if layers >= 50 else [64, 64, 128, 256] + num_filters = [64, 128, 256, 512] + + self.conv1_1 = ConvBNLayer( + in_channels=in_channels, + out_channels=32, + kernel_size=3, + stride=1, + is_relu=True) + self.conv1_2 = ConvBNLayer( + in_channels=32, + out_channels=32, + kernel_size=3, + stride=1, + is_relu=True) + self.conv1_3 = ConvBNLayer( + in_channels=32, + out_channels=64, + kernel_size=3, + stride=1, + is_relu=True) + self.pool2d_max = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + block_list = [] + if layers >= 50: + for block in range(len(depth)): + shortcut = False + for i in range(depth[block]): + if i == 0 and block != 0: + stride = (2, 1) + else: + stride = (1, 1) + bottleneck_block = BottleneckBlock( + in_channels=num_channels[block] + if i == 0 else num_filters[block] * 4, + out_channels=num_filters[block], + stride=stride, + shortcut=shortcut, + if_first=block == i == 0) + shortcut = True + block_list.append(bottleneck_block) + self.out_channels = num_filters[block] + else: + for block in range(len(depth)): + shortcut = False + for i in range(depth[block]): + if i == 0 and block != 0: + stride = (2, 1) + else: + stride = (1, 1) + + basic_block = BasicBlock( + in_channels=num_channels[block] + if i == 0 else num_filters[block], + out_channels=num_filters[block], + stride=stride, + shortcut=shortcut, + if_first=block == i == 0) + shortcut = True + block_list.append(basic_block) + self.out_channels = num_filters[block] + self.block = nn.Sequential(*block_list) + self.out_pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + + def forward(self, x): + y = self.conv1_1(x) + y = self.conv1_2(y) + y = self.conv1_3(y) + y = self.pool2d_max(y) + for block in self.block: + y = block(y) + y = self.out_pool(y) + return y + + + + +def resnet18(pretrained=False, is_gray=False,**kwargs): + """Constructs a ResNet-18 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if is_gray: + in_channels = 1 + else: + in_channels = 3 + model = ResNet(in_channels=in_channels, layers=18, **kwargs) + if pretrained: + pass + return model + +def resnet34(pretrained=False, is_gray=False, **kwargs): + """Constructs a ResNet-34 model. 
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if is_gray: + in_channels = 1 + else: + in_channels = 3 + model = ResNet(in_channels=in_channels, layers=34, **kwargs) + if pretrained: + pass + return model + +def resnet50(pretrained=False, is_gray=False, **kwargs): + """Constructs a ResNet-50 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if is_gray: + in_channels = 1 + else: + in_channels = 3 + model = ResNet(in_channels=in_channels, layers=50, **kwargs) + if pretrained: + pass + return model \ No newline at end of file diff --git a/ptocr/model/head/__pycache__/__init__.cpython-36.pyc b/ptocr/model/head/__pycache__/__init__.cpython-36.pyc index c7e77d1..6776b1f 100644 Binary files a/ptocr/model/head/__pycache__/__init__.cpython-36.pyc and b/ptocr/model/head/__pycache__/__init__.cpython-36.pyc differ diff --git a/ptocr/model/head/__pycache__/det_DBHead.cpython-36.pyc b/ptocr/model/head/__pycache__/det_DBHead.cpython-36.pyc index 0820132..868b86e 100644 Binary files a/ptocr/model/head/__pycache__/det_DBHead.cpython-36.pyc and b/ptocr/model/head/__pycache__/det_DBHead.cpython-36.pyc differ diff --git a/ptocr/model/head/__pycache__/det_FPEM_FFM_Head.cpython-36.pyc b/ptocr/model/head/__pycache__/det_FPEM_FFM_Head.cpython-36.pyc index 1342a4b..5852ece 100644 Binary files a/ptocr/model/head/__pycache__/det_FPEM_FFM_Head.cpython-36.pyc and b/ptocr/model/head/__pycache__/det_FPEM_FFM_Head.cpython-36.pyc differ diff --git a/ptocr/model/head/__pycache__/det_SASTHead.cpython-36.pyc b/ptocr/model/head/__pycache__/det_SASTHead.cpython-36.pyc index 11e7a9e..5954053 100644 Binary files a/ptocr/model/head/__pycache__/det_SASTHead.cpython-36.pyc and b/ptocr/model/head/__pycache__/det_SASTHead.cpython-36.pyc differ diff --git a/ptocr/model/head/__pycache__/rec_CRNNHead.cpython-36.pyc b/ptocr/model/head/__pycache__/rec_CRNNHead.cpython-36.pyc index 86a2251..297f9e9 100644 Binary files a/ptocr/model/head/__pycache__/rec_CRNNHead.cpython-36.pyc and b/ptocr/model/head/__pycache__/rec_CRNNHead.cpython-36.pyc differ diff --git a/ptocr/model/head/__pycache__/rec_FCHead.cpython-36.pyc b/ptocr/model/head/__pycache__/rec_FCHead.cpython-36.pyc new file mode 100644 index 0000000..b66d2eb Binary files /dev/null and b/ptocr/model/head/__pycache__/rec_FCHead.cpython-36.pyc differ diff --git a/ptocr/model/head/det_SASTHead.py b/ptocr/model/head/det_SASTHead.py index 98af128..dea784b 100644 --- a/ptocr/model/head/det_SASTHead.py +++ b/ptocr/model/head/det_SASTHead.py @@ -76,8 +76,8 @@ def __init__(self): super(FPN_Down_Fusion, self).__init__() self.fpn_down_conv1 = ConvBnRelu(3, 32, 1, 1, 0, with_relu=False) -# self.fpn_down_conv2 = ConvBnRelu(128, 64, 1, 1, 0, with_relu=False) # for 3*3 - self.fpn_down_conv2 = ConvBnRelu(64, 64, 1, 1, 0, with_relu=False) # for 7*7 + self.fpn_down_conv2 = ConvBnRelu(128, 64, 1, 1, 0, with_relu=False) # for 3*3 +# self.fpn_down_conv2 = ConvBnRelu(64, 64, 1, 1, 0, with_relu=False) # for 7*7 self.fpn_down_conv3 = ConvBnRelu(256, 128, 1, 1, 0, with_relu=False) self.fpn_down_conv4 = ConvBnRelu(32, 64, 3, 2, 1, with_relu=False) diff --git a/ptocr/model/head/rec_CRNNHead.py b/ptocr/model/head/rec_CRNNHead.py index 36c4f54..6bc452e 100644 --- a/ptocr/model/head/rec_CRNNHead.py +++ b/ptocr/model/head/rec_CRNNHead.py @@ -1,107 +1,121 @@ -#-*- coding:utf-8 _*- +# -*- coding:utf-8 _*- """ @author:fxw @file: crnn_head.py @time: 2020/07/24 """ +import torch import torch.nn as nn from torch.nn import init -class 
SeModule(nn.Module): - def __init__(self, in_size, reduction=4): - super(SeModule, self).__init__() - - self.se = nn.Sequential( - nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False), - nn.BatchNorm2d(in_size // reduction), - nn.ReLU(inplace=True), - nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False), - nn.BatchNorm2d(in_size), - nn.Sigmoid()) - for m in self.modules(): - if isinstance(m, nn.Conv2d): - init.kaiming_normal_(m.weight, mode='fan_out') - if m.bias is not None: - init.constant_(m.bias, 0) - elif isinstance(m, nn.BatchNorm2d): - init.constant_(m.weight, 1) - init.constant_(m.bias, 0) + +class channelattention(nn.Module): + def __init__(self, time_step): + super(channelattention, self).__init__() + self.attention = nn.Sequential(nn.Linear(time_step, 8), + nn.Linear(8, 1)) + + def forward(self, x): + x = x.permute(0, 2, 1) + att = self.attention(x) + att = torch.sigmoid(att) + out = att * x + return out.permute(2, 0, 1) + + +class timeattention(nn.Module): + def __init__(self, inchannel): + super(timeattention, self).__init__() + self.attention = nn.Sequential(nn.Linear(inchannel, inchannel // 4), + nn.Linear(inchannel // 4, inchannel // 8), + nn.Linear(inchannel // 8, 1)) def forward(self, x): - return x * self.se(x) - + att = self.attention(x) + att = torch.sigmoid(att) + out = att * x + return out.permute(1, 0, 2) + + class BLSTM(nn.Module): def __init__(self, nIn, nHidden, nOut): super(BLSTM, self).__init__() - self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True) + self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True, batch_first=True) self.embedding = nn.Linear(nHidden * 2, nOut) def forward(self, input): + if not hasattr(self, '_flattened'): + self.rnn.flatten_parameters() + setattr(self, '_flattened', True) recurrent, _ = self.rnn(input) - T, b, h = recurrent.size() - t_rec = recurrent.view(T * b, h) + b, T, h = recurrent.size() + t_rec = recurrent.contiguous().view(b * T, h) - output = self.embedding(t_rec) # [T * b, nOut] - output = output.view(T, b, -1) + output = self.embedding(t_rec) # [b * T, nOut] + output = output.contiguous().view(b, T, -1) return output + class CRNN_Head(nn.Module): - def __init__(self,use_conv=False, - use_attention=False, + def __init__(self, use_attention=False, use_lstm=True, - lstm_num=2, - inchannel=512, - hiddenchannel=128, - classes=1000): - super(CRNN_Head,self).__init__() + time_step=25, + lstm_num=2, + inchannel=512, + hiddenchannel=128, + classes=1000): + super(CRNN_Head, self).__init__() self.use_lstm = use_lstm self.lstm_num = lstm_num - self.use_conv = use_conv - if use_attention: - self.attention = SeModule(inchannel) self.use_attention = use_attention - if(use_lstm): - assert lstm_num>0 ,Exception('lstm_num need to more than 0 if use_lstm = True') + + if use_attention: + self.channel_attention = channelattention(time_step=time_step) + self.time_attention = timeattention(inchannel) + if (use_lstm): + assert lstm_num > 0, Exception('lstm_num need to more than 0 if use_lstm = True') for i in range(lstm_num): - if(i==0): - if(lstm_num==1): - setattr(self, 'lstm_{}'.format(i + 1), BLSTM(inchannel, hiddenchannel,classes)) + if (i == 0): + if (lstm_num == 1): + setattr(self, 'lstm_{}'.format(i + 1), BLSTM(inchannel, hiddenchannel, classes)) else: - setattr(self, 'lstm_{}'.format(i + 1), BLSTM(inchannel,hiddenchannel,hiddenchannel)) - elif(i==lstm_num-1): + setattr(self, 'lstm_{}'.format(i + 1), BLSTM(inchannel, hiddenchannel, hiddenchannel)) + elif (i == lstm_num - 
1): setattr(self, 'lstm_{}'.format(i + 1), BLSTM(hiddenchannel, hiddenchannel, classes)) else: setattr(self, 'lstm_{}'.format(i + 1), BLSTM(hiddenchannel, hiddenchannel, hiddenchannel)) - elif(use_conv): - self.out = nn.Conv2d(inchannel, classes, kernel_size=1, padding=0) else: - self.out = nn.Linear(inchannel,classes) + self.out = nn.Linear(inchannel, classes) def forward(self, x): b, c, h, w = x.size() assert h == 1, "the height of conv must be 1" - - if self.use_attention: - x = self.attention(x) - - if(self.use_conv): - x = self.out(x) - x = x.squeeze(2) - x = x.permute(2, 0, 1) - return x - + x = x.squeeze(2) - x = x.permute(2, 0, 1) # [w, b, c] + x = x.permute(0, 2, 1) # [b, w, c] + + ############ + if self.use_attention: + x = self.channel_attention(x) + x = self.time_attention(x) + + ############ + + feau = [] if self.use_lstm: for i in range(self.lstm_num): x = getattr(self, 'lstm_{}'.format(i + 1))(x) + feau.append(x) else: + feau.append(x) x = self.out(x) - return x - + + return x, feau + + def backward_hook(self, module, grad_input, grad_output): for g in grad_input: g[g != g] = 0 \ No newline at end of file diff --git a/ptocr/model/head/rec_FCHead.py b/ptocr/model/head/rec_FCHead.py new file mode 100644 index 0000000..6f6d95a --- /dev/null +++ b/ptocr/model/head/rec_FCHead.py @@ -0,0 +1,106 @@ +import torch +import torch.nn as nn + +class FCModule(nn.Module): + """FCModule + Args: + """ + def __init__(self, + in_channels, + out_channels, + bias=True, + activation='relu', + inplace=True, + dropout=None, + order=('fc', 'act')): + super(FCModule, self).__init__() + self.order = order + self.activation = activation + self.inplace = inplace + + self.with_activatation = activation is not None + self.with_dropout = dropout is not None + + self.fc = nn.Linear(in_channels, out_channels, bias) + + # build activation layer + if self.with_activatation: + # TODO: introduce `act_cfg` and supports more activation layers + if self.activation not in ['relu', 'tanh']: + raise ValueError('{} is currently not supported.'.format( + self.activation)) + if self.activation == 'relu': + self.activate = nn.ReLU(inplace=inplace) + elif self.activation == 'tanh': + self.activate = nn.Tanh() + + if self.with_dropout: + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x): + if self.order == ('fc', 'act'): + x = self.fc(x) + + if self.with_activatation: + x = self.activate(x) + elif self.order == ('act', 'fc'): + if self.with_activatation: + x = self.activate(x) + x = self.fc(x) + + if self.with_dropout: + x = self.dropout(x) + + return x + +class FCModules(nn.Module): + """FCModules + Args: + """ + def __init__(self, + in_channels, + out_channels, + bias=True, + activation='relu', + inplace=True, + dropouts=None, + num_fcs=1): + super().__init__() + + if dropouts is not None: + assert num_fcs == len(dropouts) + dropout = dropouts[0] + else: + dropout = None + + layers = [FCModule(in_channels, out_channels, bias, activation, inplace, dropout)] + for ii in range(1, num_fcs): + if dropouts is not None: + dropout = dropouts[ii] + else: + dropout = None + layers.append(FCModule(out_channels, out_channels, bias, activation, inplace, dropout)) + + self.block = nn.Sequential(*layers) + + def forward(self, x): + feat = self.block(x) + return feat + + +class FC_Head(nn.Module): + def __init__(self,in_channels, + out_channels,max_length,num_class): + super(FC_Head,self).__init__() + self.adpooling = nn.AdaptiveAvgPool2d(1) + self.fc_end = FCModules(in_channels=in_channels,out_channels=out_channels) + 
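# Illustrative aside (not part of the original patch; FC_Head is completed
# just below): `fc_out` emits (num_class + 1) * (max_length + 1) logits per
# image, and forward() reshapes them to [batch, max_length + 1, num_class + 1],
# i.e. one (num_class + 1)-way classification per character slot, with the
# extra class and slot covering the end token. Toy check of that reshape,
# assuming num_class=36 and max_length=25:
import torch
flat = torch.randn(2, (36 + 1) * (25 + 1))   # what fc_out would produce, [2, 962]
seq = flat.view(2, 25 + 1, 36 + 1)           # per-slot class logits, [2, 26, 37]
assert seq.shape == (2, 26, 37)
# end aside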
self.fc_out = nn.Linear(out_channels,(num_class+1)*(max_length+1)) + self.num_class = num_class + self.max_length = max_length + def forward(self,x): + x = self.adpooling(x) + x = x.view(x.shape[0],-1) + x = self.fc_end(x) + x1 = self.fc_out(x) + x2 = x1.view(x1.shape[0],self.max_length+1,self.num_class+1) + return x2,x1 \ No newline at end of file diff --git a/ptocr/model/loss/__pycache__/__init__.cpython-36.pyc b/ptocr/model/loss/__pycache__/__init__.cpython-36.pyc index 5584573..6c83201 100644 Binary files a/ptocr/model/loss/__pycache__/__init__.cpython-36.pyc and b/ptocr/model/loss/__pycache__/__init__.cpython-36.pyc differ diff --git a/ptocr/model/loss/__pycache__/basical_loss.cpython-36.pyc b/ptocr/model/loss/__pycache__/basical_loss.cpython-36.pyc index 20f2ca5..ca9da33 100644 Binary files a/ptocr/model/loss/__pycache__/basical_loss.cpython-36.pyc and b/ptocr/model/loss/__pycache__/basical_loss.cpython-36.pyc differ diff --git a/ptocr/model/loss/__pycache__/centerloss.cpython-36.pyc b/ptocr/model/loss/__pycache__/centerloss.cpython-36.pyc new file mode 100644 index 0000000..cf248be Binary files /dev/null and b/ptocr/model/loss/__pycache__/centerloss.cpython-36.pyc differ diff --git a/ptocr/model/loss/__pycache__/ctc_loss.cpython-36.pyc b/ptocr/model/loss/__pycache__/ctc_loss.cpython-36.pyc index bbc0036..d9b53cc 100644 Binary files a/ptocr/model/loss/__pycache__/ctc_loss.cpython-36.pyc and b/ptocr/model/loss/__pycache__/ctc_loss.cpython-36.pyc differ diff --git a/ptocr/model/loss/__pycache__/db_loss.cpython-36.pyc b/ptocr/model/loss/__pycache__/db_loss.cpython-36.pyc index 648b897..d51b198 100644 Binary files a/ptocr/model/loss/__pycache__/db_loss.cpython-36.pyc and b/ptocr/model/loss/__pycache__/db_loss.cpython-36.pyc differ diff --git a/ptocr/model/loss/__pycache__/fc_loss.cpython-36.pyc b/ptocr/model/loss/__pycache__/fc_loss.cpython-36.pyc new file mode 100644 index 0000000..ae1ccdc Binary files /dev/null and b/ptocr/model/loss/__pycache__/fc_loss.cpython-36.pyc differ diff --git a/ptocr/model/loss/__pycache__/sast_loss.cpython-36.pyc b/ptocr/model/loss/__pycache__/sast_loss.cpython-36.pyc index d7dfd62..bdc5329 100644 Binary files a/ptocr/model/loss/__pycache__/sast_loss.cpython-36.pyc and b/ptocr/model/loss/__pycache__/sast_loss.cpython-36.pyc differ diff --git a/ptocr/model/loss/basical_loss.py b/ptocr/model/loss/basical_loss.py index 649858a..0c74cd3 100644 --- a/ptocr/model/loss/basical_loss.py +++ b/ptocr/model/loss/basical_loss.py @@ -6,8 +6,42 @@ """ import torch import torch.nn as nn +import torch.nn.functional as F import numpy as np +class MulClassLoss(nn.Module): + def __init__(self, ): + super(MulClassLoss, self).__init__() + + def forward(self,pre_score,gt_score,n_class): + gt_score = gt_score.reshape(-1) + index = gt_score>0 + if index.sum()>0: + pre_score = pre_score.permute(0,2,3,1).reshape(-1,n_class) + gt_score = gt_score[index] + gt_score = gt_score - 1 + pre_score = pre_score[index] + class_loss = F.cross_entropy(pre_score,gt_score.long(), ignore_index=-1) + else: + class_loss = torch.tensor(0.0).cuda() + return class_loss + + +class CrossEntropyLoss(nn.Module): + + def __init__(self, weight=None, size_average=None, ignore_index=-100, + reduce=None, reduction='mean'): + super(CrossEntropyLoss, self).__init__() + self.criteron = nn.CrossEntropyLoss(weight=weight, + size_average=size_average, + ignore_index=ignore_index, + reduce=reduce, + reduction=reduction) + + def forward(self, pred, target, *args): + return self.criteron(pred.contiguous().view(-1, 
pred.shape[-1]), target.to(pred.device).contiguous().view(-1)) + + class DiceLoss(nn.Module): def __init__(self,eps=1e-6): super(DiceLoss,self).__init__() @@ -183,10 +217,12 @@ def forward(self, return balance_loss -def focal_ctc_loss(ctc_loss,alpha=0.25,gamma=0.5): # 0.99,1 +def focal_ctc_loss(ctc_loss,alpha=0.95,gamma=1): # 0.99,1 +# import pdb +# pdb.set_trace() prob = torch.exp(-ctc_loss) focal_loss = alpha*(1-prob).pow(gamma)*ctc_loss - return focal_loss.mean() + return focal_loss.sum() class focal_bin_cross_entropy(nn.Module): diff --git a/ptocr/model/loss/ctc_loss.py b/ptocr/model/loss/ctc_loss.py index 3aa8c0f..cde7a8f 100644 --- a/ptocr/model/loss/ctc_loss.py +++ b/ptocr/model/loss/ctc_loss.py @@ -1,19 +1,22 @@ -# from torch.nn import CTCLoss as PytorchCTCLoss -from warpctc_pytorch import CTCLoss as PytorchCTCLoss import torch.nn as nn -from .basical_loss import focal_ctc_loss - class CTCLoss(nn.Module): def __init__(self,config): super(CTCLoss,self).__init__() -# self.criterion = PytorchCTCLoss(reduction = config['loss']['reduction']) - self.criterion = PytorchCTCLoss() self.config = config + if config['loss']['ctc_type'] == 'warpctc': + from warpctc_pytorch import CTCLoss as PytorchCTCLoss + self.criterion = PytorchCTCLoss() + else: + from torch.nn import CTCLoss as PytorchCTCLoss + self.criterion = PytorchCTCLoss(reduction = 'none') def forward(self,pre_batch,gt_batch): preds,preds_size = pre_batch['preds'],pre_batch['preds_size'] labels,labels_len = gt_batch['labels'],gt_batch['labels_len'] + if self.config['loss']['ctc_type'] != 'warpctc': + preds = preds.log_softmax(2).requires_grad_() # torch.ctcloss loss = self.criterion(preds, labels, preds_size, labels_len) - if self.config['loss']['reduction']=='none': - loss = focal_ctc_loss(loss) + if self.config['loss']['use_ctc_weight']: + loss = gt_batch['ctc_loss_weight']*loss.cuda() + loss = loss.sum() return loss/self.config['trainload']['batch_size'] \ No newline at end of file diff --git a/ptocr/model/loss/db_loss.py b/ptocr/model/loss/db_loss.py index d8ed9e7..38dc872 100644 --- a/ptocr/model/loss/db_loss.py +++ b/ptocr/model/loss/db_loss.py @@ -6,7 +6,11 @@ """ import torch import torch.nn as nn -from .basical_loss import MaskL1Loss,BalanceCrossEntropyLoss,DiceLoss,FocalCrossEntropyLoss +from .basical_loss import MaskL1Loss,BalanceCrossEntropyLoss,DiceLoss,FocalCrossEntropyLoss,MulClassLoss + + + + class DBLoss(nn.Module): def __init__(self, l1_scale=10, bce_scale=1,eps=1e-6): super(DBLoss, self).__init__() @@ -27,4 +31,31 @@ def forward(self, pred_bach, gt_batch): metrics.update(**l1_metric) else: loss = bce_loss + return loss, metrics + +class DBLossMul(nn.Module): + def __init__(self, n_class,l1_scale=10, bce_scale=1, class_scale=1,eps=1e-6): + super(DBLossMul, self).__init__() + self.dice_loss = DiceLoss(eps) + self.l1_loss = MaskL1Loss() + self.bce_loss = BalanceCrossEntropyLoss() + self.class_loss = MulClassLoss() + self.l1_scale = l1_scale + self.bce_scale = bce_scale + self.class_scale = class_scale + self.n_class = n_class + + def forward(self, pred_bach, gt_batch): + bce_loss = self.bce_loss(pred_bach['binary'][:,0], gt_batch['gt'], gt_batch['mask']) + class_loss = self.class_loss(pred_bach['binary_class'] ,gt_batch['gt_class'],self.n_class) + metrics = dict(loss_bce=bce_loss) + if 'thresh' in pred_bach: + l1_loss, l1_metric = self.l1_loss(pred_bach['thresh'][:,0], gt_batch['thresh_map'], gt_batch['thresh_mask']) + dice_loss = self.dice_loss(pred_bach['thresh_binary'][:,0], gt_batch['gt'], gt_batch['mask']) + 
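# Illustrative aside on the CTC loss changes a few hunks above (not part of
# the original patch): focal_ctc_loss now defaults to alpha=0.95, gamma=1 and
# sums rather than averages, i.e. focal = alpha * (1 - exp(-ctc))**gamma * ctc
# per sample; separately, the reworked CTCLoss wrapper runs torch's CTCLoss
# with reduction='none', optionally scales each sample by ctc_loss_weight, and
# divides the summed loss by the batch size. Minimal standalone sketch of the
# focal variant, assuming a tensor of per-sample CTC losses:
import torch

def focal_ctc(per_sample_ctc, alpha=0.95, gamma=1.0):
    prob = torch.exp(-per_sample_ctc)        # approx. probability of the correct alignment
    return (alpha * (1.0 - prob) ** gamma * per_sample_ctc).sum()
# end aside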
metrics['loss_thresh'] = dice_loss + metrics['loss_class'] = class_loss + loss = dice_loss + self.l1_scale * l1_loss + bce_loss * self.bce_scale + class_loss * self.class_scale + metrics.update(**l1_metric) + else: + loss = bce_loss return loss, metrics \ No newline at end of file diff --git a/ptocr/model/loss/fc_loss.py b/ptocr/model/loss/fc_loss.py new file mode 100644 index 0000000..293b7e2 --- /dev/null +++ b/ptocr/model/loss/fc_loss.py @@ -0,0 +1,14 @@ +import torch +import torch.nn as nn +from .basical_loss import CrossEntropyLoss + +class FCLoss(nn.Module): + def __init__(self,ignore_index = -1): + super(FCLoss, self).__init__() + self.cross_entropy_loss = CrossEntropyLoss(ignore_index = ignore_index) + + + def forward(self, pred_bach, gt_batch): + loss = self.cross_entropy_loss(pred_bach['pred'],gt_batch['gt']) + metrics = dict(loss_fc=loss) + return loss, metrics \ No newline at end of file diff --git a/ptocr/model/segout/__pycache__/__init__.cpython-36.pyc b/ptocr/model/segout/__pycache__/__init__.cpython-36.pyc index 6eb960b..876605d 100644 Binary files a/ptocr/model/segout/__pycache__/__init__.cpython-36.pyc and b/ptocr/model/segout/__pycache__/__init__.cpython-36.pyc differ diff --git a/ptocr/model/segout/__pycache__/det_DB_segout.cpython-36.pyc b/ptocr/model/segout/__pycache__/det_DB_segout.cpython-36.pyc index 3368407..844eabc 100644 Binary files a/ptocr/model/segout/__pycache__/det_DB_segout.cpython-36.pyc and b/ptocr/model/segout/__pycache__/det_DB_segout.cpython-36.pyc differ diff --git a/ptocr/model/segout/__pycache__/det_SAST_segout.cpython-36.pyc b/ptocr/model/segout/__pycache__/det_SAST_segout.cpython-36.pyc index 882bfe8..5de7975 100644 Binary files a/ptocr/model/segout/__pycache__/det_SAST_segout.cpython-36.pyc and b/ptocr/model/segout/__pycache__/det_SAST_segout.cpython-36.pyc differ diff --git a/ptocr/model/segout/det_DB_segout.py b/ptocr/model/segout/det_DB_segout.py index 895472e..2d9ee69 100644 --- a/ptocr/model/segout/det_DB_segout.py +++ b/ptocr/model/segout/det_DB_segout.py @@ -88,3 +88,104 @@ def forward(self, fuse,img): def step_function(self, x, y): return torch.reciprocal(1 + torch.exp(-self.k * (x - y))) + + +class SegDetectorMul(nn.Module): + def __init__(self,n_classes = 1, + inner_channels=256, k=10, + adaptive=False, + serial=False, bias=False, + *args, **kwargs): + ''' + bias: Whether conv layers have bias or not. + adaptive: Whether to use adaptive threshold training or not. + smooth: If true, use bilinear instead of deconv. + serial: If true, thresh prediction will combine segmentation result as input. 
+ ''' + super(SegDetectorMul, self).__init__() + self.k = k + self.serial = serial + + self.binarize = nn.Sequential( + nn.Conv2d(inner_channels, inner_channels // 4, 3, padding=1, bias=bias), + nn.BatchNorm2d(inner_channels // 4), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(inner_channels // 4, inner_channels // 4, 2, 2), + nn.BatchNorm2d(inner_channels // 4), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(inner_channels // 4, 1, 2, 2), + nn.Sigmoid() + ) + + self.classhead = nn.Sequential( + nn.Conv2d(inner_channels, inner_channels // 4, 3, padding=1, bias=bias), + nn.BatchNorm2d(inner_channels // 4), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(inner_channels // 4, inner_channels // 4, 2, 2), + nn.BatchNorm2d(inner_channels // 4), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(inner_channels // 4, n_classes, 2, 2) + ) + + + self.binarize.apply(self.weights_init) + self.classhead.apply(self.weights_init) + + self.adaptive = adaptive + if adaptive: + self.thresh = self._init_thresh( + inner_channels, serial=serial,bias=bias) + self.thresh.apply(self.weights_init) + + def weights_init(self, m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.kaiming_normal_(m.weight.data) + elif classname.find('BatchNorm') != -1: + m.weight.data.fill_(1.) + m.bias.data.fill_(1e-4) + + def _init_thresh(self, inner_channels, + serial=False, bias=False): + in_channels = inner_channels + if serial: + in_channels += 1 + self.thresh = nn.Sequential( + nn.Conv2d(in_channels, inner_channels // + 4, 3, padding=1, bias=bias), + nn.BatchNorm2d(inner_channels // 4), + nn.ReLU(inplace=True), + self._init_upsample(inner_channels // 4, inner_channels // 4), + nn.BatchNorm2d(inner_channels // 4), + nn.ReLU(inplace=True), + self._init_upsample(inner_channels // 4, 1), + nn.Sigmoid() + ) + return self.thresh + + def _init_upsample(self, in_channels, out_channels): + return nn.ConvTranspose2d(in_channels, out_channels, 2, 2) + + def forward(self, fuse,img): + + binary = self.binarize(fuse) + binary_class = self.classhead(fuse) + + if self.training: + result = OrderedDict(binary=binary) + result.update(binary_class = binary_class) + else: + return binary,binary_class + + if self.adaptive and self.training: + if self.serial: + fuse = torch.cat( + (fuse, nn.functional.interpolate( + binary, fuse.shape[2:])), 1) + thresh = self.thresh(fuse) + thresh_binary = self.step_function(binary, thresh) + result.update(thresh=thresh, thresh_binary=thresh_binary) + return result + + def step_function(self, x, y): + return torch.reciprocal(1 + torch.exp(-self.k * (x - y))) diff --git a/ptocr/optimizer.py b/ptocr/optimizer.py index 2df2fcd..9c385f1 100644 --- a/ptocr/optimizer.py +++ b/ptocr/optimizer.py @@ -6,21 +6,21 @@ """ import torch -def AdamDecay(config,model): - optimizer = torch.optim.Adam(model.parameters(), lr=config['optimizer']['base_lr'], +def AdamDecay(config,parameters): + optimizer = torch.optim.Adam(parameters, lr=config['optimizer']['base_lr'], betas=(config['optimizer']['beta1'], config['optimizer']['beta2']), weight_decay=config['optimizer']['weight_decay']) return optimizer -def SGDDecay(config,model): - optimizer = torch.optim.SGD(model.parameters(), lr=config['optimizer']['base_lr'], +def SGDDecay(config,parameters): + optimizer = torch.optim.SGD(parameters, lr=config['optimizer']['base_lr'], momentum=config['optimizer']['momentum'], weight_decay=config['optimizer']['weight_decay']) return optimizer -def RMSPropDecay(config,model): - optimizer = torch.optim.RMSprop(model.parameters(), 
lr=config['optimizer']['base_lr'],
+def RMSPropDecay(config,parameters):
+    optimizer = torch.optim.RMSprop(parameters, lr=config['optimizer']['base_lr'],
                                     alpha=config['optimizer']['alpha'],
                                     weight_decay=config['optimizer']['weight_decay'],
                                     momentum=config['optimizer']['momentum'])
@@ -30,11 +30,32 @@ def RMSPropDecay(config,model):
 def lr_poly(base_lr, epoch, max_epoch=1200, factor=0.9):
     return base_lr*((1-float(epoch)/max_epoch)**(factor))
+
+import math  # required for math.cos / math.pi in SGDR below
+
+def SGDR(lr_max,lr_min,T_cur,T_m,ratio=0.3):
+    """
+    :param lr_max: maximum learning rate
+    :param lr_min: minimum learning rate
+    :param T_cur: current epoch or iteration
+    :param T_m: number of epochs/iterations between warm restarts
+    :param ratio: decay ratio applied to lr_max at each restart
+    :return:
+    """
+    if T_cur % T_m == 0 and T_cur != 0:
+        lr_max = lr_max - lr_max * ratio
+    lr = lr_min+1/2*(lr_max-lr_min)*(1+math.cos((T_cur%T_m/T_m)*math.pi))
+    return lr,lr_max
+
+
 def adjust_learning_rate_poly(config, optimizer, epoch):
     lr = lr_poly(config['optimizer']['base_lr'], epoch, config['base']['n_epoch'], config['optimizer_decay']['factor'])
     optimizer.param_groups[0]['lr'] = lr
-
+
+def adjust_learning_rate_sgdr(config, optimizer, epoch):
+    lr,lr_max = SGDR(config['optimizer']['lr_max'],config['optimizer']['lr_min'],epoch,config['optimizer']['T_m'],config['optimizer']['ratio'])
+    optimizer.param_groups[0]['lr'] = lr
+    config['optimizer']['lr_max'] = lr_max
+
 def adjust_learning_rate(config, optimizer, epoch):
     if epoch in config['optimizer_decay']['schedule']:
         adjust_lr = optimizer.param_groups[0]['lr'] * config['optimizer_decay']['gama']
diff --git a/ptocr/postprocess/DBpostprocess.py b/ptocr/postprocess/DBpostprocess.py
index e483e89..959c270 100644
--- a/ptocr/postprocess/DBpostprocess.py
+++ b/ptocr/postprocess/DBpostprocess.py
@@ -216,4 +216,215 @@ def __call__(self, pred, ratio_list):
             boxes_batch.append(boxes)
             score_batch.append(score)
-        return boxes_batch, score_batch
\ No newline at end of file
+        return boxes_batch, score_batch
+
+class DBPostProcessMul(object):
+    """
+    The post process for Differentiable Binarization (DB).
+ """ + + def __init__(self, config): + self.thresh = config['postprocess']['thresh'] + self.box_thresh = config['postprocess']['box_thresh'] + self.max_candidates = config['postprocess']['max_candidates'] + self.is_poly = config['postprocess']['is_poly'] + self.unclip_ratio = config['postprocess']['unclip_ratio'] + self.min_size = config['postprocess']['min_size'] + + def polygons_from_bitmap(self, pred,classes, _bitmap, dest_width, dest_height): + ''' + _bitmap: single map with shape (1, H, W), + whose values are binarized as {0, 1} + ''' + + + bitmap = _bitmap + pred = pred + height, width = bitmap.shape + boxes = [] + scores = [] + + contours, _ = cv2.findContours( + (bitmap*255).astype(np.uint8), + cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + + for contour in contours[:self.max_candidates]: + epsilon = 0.001 * cv2.arcLength(contour, True) + approx = cv2.approxPolyDP(contour, epsilon, True) + points = approx.reshape((-1, 2)) + if points.shape[0] < 4: + continue + # _, sside = self.get_mini_boxes(contour) + # if sside < self.min_size: + # continue + score = self.box_score_fast(pred,classes, points.reshape(-1, 2)) + if self.box_thresh > score: + continue + + if points.shape[0] > 2: + box = self.unclip(points, self.unclip_ratio) + if len(box) > 1: + continue + else: + continue + box ,type_class = box.reshape(-1, 2) + _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2))) + if sside < self.min_size + 2: + continue + + if not isinstance(dest_width, int): + dest_width = dest_width.item() + dest_height = dest_height.item() + + box[:, 0] = np.clip( + np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 1] = np.clip( + np.round(box[:, 1] / height * dest_height), 0, dest_height) + boxes.append(box.tolist()) + scores.append(score) + return boxes, scores + + def boxes_from_bitmap(self, pred,classes, _bitmap, dest_width, dest_height): + ''' + _bitmap: single map with shape (1, H, W), + whose values are binarized as {0, 1} + ''' + + bitmap = _bitmap + height, width = bitmap.shape + + outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + if len(outs) == 3: + img, contours, _ = outs[0], outs[1], outs[2] + elif len(outs) == 2: + contours, _ = outs[0], outs[1] + + num_contours = min(len(contours), self.max_candidates) + boxes = np.zeros((num_contours, 4, 2), dtype=np.int16) + scores = np.zeros((num_contours, ), dtype=np.float32) + type_classes = np.zeros((num_contours, ), dtype=np.float32) + + for index in range(num_contours): + contour = contours[index] + points, sside = self.get_mini_boxes(contour) + if sside < self.min_size: + continue + points = np.array(points) + score,type_class = self.box_score_fast(pred,classes, points.reshape(-1, 2)) + if self.box_thresh > score: + continue + box = self.unclip(points,self.unclip_ratio).reshape(-1, 1, 2) + box, sside = self.get_mini_boxes(box) + if sside < self.min_size + 2: + continue + box = np.array(box) + if not isinstance(dest_width, int): + dest_width = dest_width.item() + dest_height = dest_height.item() + + box[:, 0] = np.clip( + np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 1] = np.clip( + np.round(box[:, 1] / height * dest_height), 0, dest_height) + boxes[index, :, :] = box.astype(np.int16) + scores[index] = score + type_classes[index] = type_class + return boxes, scores,type_classes + + def unclip(self, box, unclip_ratio=2): + poly = Polygon(box) + distance = poly.area * unclip_ratio / poly.length + offset = pyclipper.PyclipperOffset() + offset.AddPath(box, pyclipper.JT_ROUND, 
pyclipper.ET_CLOSEDPOLYGON) + expanded = np.array(offset.Execute(distance)) + return expanded + + def get_mini_boxes(self, contour): + bounding_box = cv2.minAreaRect(contour) + points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) + + index_1, index_2, index_3, index_4 = 0, 1, 2, 3 + if points[1][1] > points[0][1]: + index_1 = 0 + index_4 = 1 + else: + index_1 = 1 + index_4 = 0 + if points[3][1] > points[2][1]: + index_2 = 2 + index_3 = 3 + else: + index_2 = 3 + index_3 = 2 + + box = [ + points[index_1], points[index_2], points[index_3], points[index_4] + ] + return box, min(bounding_box[1]) + + def box_score_fast(self, bitmap,classes,_box): + h, w = bitmap.shape[:2] + box = _box.copy() + xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1) + xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1) + ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1) + ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + box[:, 0] = box[:, 0] - xmin + box[:, 1] = box[:, 1] - ymin + cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) + classes = classes[ymin:ymax + 1, xmin:xmax + 1] + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0],np.argmax(np.bincount(classes.reshape(-1).astype(np.int32))) + + def __call__(self, pred,pred_class, ratio_list): + pred = pred[:, 0, :, :] + segmentation = pred > self.thresh + classes = pred_class[:, 0, :, :] + + boxes_batch = [] + score_batch = [] + type_classes_batch = [] + for batch_index in range(pred.shape[0]): + height, width = pred.shape[-2:] + if(self.is_poly): + tmp_boxes, tmp_scores = self.polygons_from_bitmap( + pred[batch_index], classes[batch_index],segmentation[batch_index], width, height) + + boxes = [] + score = [] + for k in range(len(tmp_boxes)): + if tmp_scores[k] > self.box_thresh: + boxes.append(tmp_boxes[k]) + score.append(tmp_scores[k]) + if len(boxes) > 0: + ratio_w, ratio_h = ratio_list[batch_index] + for i in range(len(boxes)): + boxes[i] = np.array(boxes[i]) + boxes[i][:, 0] = boxes[i][:, 0] * ratio_w + boxes[i][:, 1] = boxes[i][:, 1] * ratio_h + + boxes_batch.append(boxes) + score_batch.append(score) + else: + tmp_boxes, tmp_scores,type_classes = self.boxes_from_bitmap( + pred[batch_index], classes[batch_index],segmentation[batch_index], width, height) + + boxes = [] + score = [] + _classes = [] + for k in range(len(tmp_boxes)): + if tmp_scores[k] > self.box_thresh: + boxes.append(tmp_boxes[k]) + score.append(tmp_scores[k]) + _classes.append(type_classes[k]) + if len(boxes) > 0: + boxes = np.array(boxes) + + ratio_w, ratio_h = ratio_list[batch_index] + boxes[:, :, 0] = boxes[:, :, 0] * ratio_w + boxes[:, :, 1] = boxes[:, :, 1] * ratio_h + type_classes_batch.append(_classes) + boxes_batch.append(boxes) + score_batch.append(score) + return boxes_batch,score_batch,type_classes_batch \ No newline at end of file diff --git a/ptocr/postprocess/__pycache__/DBpostprocess.cpython-36.pyc b/ptocr/postprocess/__pycache__/DBpostprocess.cpython-36.pyc new file mode 100644 index 0000000..e6bd8b5 Binary files /dev/null and b/ptocr/postprocess/__pycache__/DBpostprocess.cpython-36.pyc differ diff --git a/ptocr/postprocess/__pycache__/SASTpostprocess.cpython-36.pyc b/ptocr/postprocess/__pycache__/SASTpostprocess.cpython-36.pyc new file mode 100644 index 0000000..9343f87 Binary files /dev/null and b/ptocr/postprocess/__pycache__/SASTpostprocess.cpython-36.pyc differ diff --git 
a/ptocr/postprocess/__pycache__/__init__.cpython-36.pyc b/ptocr/postprocess/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000..7936da5 Binary files /dev/null and b/ptocr/postprocess/__pycache__/__init__.cpython-36.pyc differ diff --git a/ptocr/postprocess/__pycache__/locality_aware_nms.cpython-36.pyc b/ptocr/postprocess/__pycache__/locality_aware_nms.cpython-36.pyc new file mode 100644 index 0000000..2aa05e9 Binary files /dev/null and b/ptocr/postprocess/__pycache__/locality_aware_nms.cpython-36.pyc differ diff --git a/ptocr/postprocess/dbprocess/__pycache__/__init__.cpython-36.pyc b/ptocr/postprocess/dbprocess/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000..a00defc Binary files /dev/null and b/ptocr/postprocess/dbprocess/__pycache__/__init__.cpython-36.pyc differ diff --git a/ptocr/postprocess/lanms/__pycache__/__init__.cpython-36.pyc b/ptocr/postprocess/lanms/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000..19a2476 Binary files /dev/null and b/ptocr/postprocess/lanms/__pycache__/__init__.cpython-36.pyc differ diff --git a/ptocr/utils/__pycache__/__init__.cpython-36.pyc b/ptocr/utils/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000..1528833 Binary files /dev/null and b/ptocr/utils/__pycache__/__init__.cpython-36.pyc differ diff --git a/ptocr/utils/__pycache__/cal_iou_acc.cpython-36.pyc b/ptocr/utils/__pycache__/cal_iou_acc.cpython-36.pyc new file mode 100644 index 0000000..43a58c4 Binary files /dev/null and b/ptocr/utils/__pycache__/cal_iou_acc.cpython-36.pyc differ diff --git a/ptocr/utils/__pycache__/gen_teacher_model.cpython-36.pyc b/ptocr/utils/__pycache__/gen_teacher_model.cpython-36.pyc new file mode 100644 index 0000000..c08392a Binary files /dev/null and b/ptocr/utils/__pycache__/gen_teacher_model.cpython-36.pyc differ diff --git a/ptocr/utils/__pycache__/logger.cpython-36.pyc b/ptocr/utils/__pycache__/logger.cpython-36.pyc new file mode 100644 index 0000000..d1f7b05 Binary files /dev/null and b/ptocr/utils/__pycache__/logger.cpython-36.pyc differ diff --git a/ptocr/utils/__pycache__/metrics.cpython-36.pyc b/ptocr/utils/__pycache__/metrics.cpython-36.pyc new file mode 100644 index 0000000..69e442b Binary files /dev/null and b/ptocr/utils/__pycache__/metrics.cpython-36.pyc differ diff --git a/ptocr/utils/__pycache__/prune_script.cpython-36.pyc b/ptocr/utils/__pycache__/prune_script.cpython-36.pyc new file mode 100644 index 0000000..8f331e1 Binary files /dev/null and b/ptocr/utils/__pycache__/prune_script.cpython-36.pyc differ diff --git a/ptocr/utils/__pycache__/transform_label.cpython-36.pyc b/ptocr/utils/__pycache__/transform_label.cpython-36.pyc new file mode 100644 index 0000000..b70e4b6 Binary files /dev/null and b/ptocr/utils/__pycache__/transform_label.cpython-36.pyc differ diff --git a/ptocr/utils/__pycache__/util_function.cpython-36.pyc b/ptocr/utils/__pycache__/util_function.cpython-36.pyc new file mode 100644 index 0000000..3e86539 Binary files /dev/null and b/ptocr/utils/__pycache__/util_function.cpython-36.pyc differ diff --git a/ptocr/utils/gen_teacher_model.py b/ptocr/utils/gen_teacher_model.py index c3f9e74..5a42767 100644 --- a/ptocr/utils/gen_teacher_model.py +++ b/ptocr/utils/gen_teacher_model.py @@ -30,7 +30,6 @@ def forward(self,pre_score,gt_score,train_mask): dice_loss = torch.mean(d) return 1 - dice_loss - def GetTeacherModel(args): config = yaml.load(open(args.t_config, 'r', encoding='utf-8'), Loader=yaml.FullLoader) model = 
create_module(config['architectures']['model_function'])(config)
@@ -41,18 +40,18 @@ def GetTeacherModel(args):
 
 class DistilLoss(nn.Module):
     def __init__(self):
+        super(DistilLoss, self).__init__()
         self.mse = nn.MSELoss()
         self.diceloss = DiceLoss()
         self.ignore = ['thresh']
+
     def forward(self, s_map, t_map):
         loss = 0
-#         import pdb
-#         pdb.set_trace()
         for key in s_map.keys():
             if(key in self.ignore):
                 continue
-            loss+=self.diceloss(s_map[key],t_map[key],torch.ones(t_map[key].shape).cuda())
+            loss += self.diceloss(s_map[key],t_map[key],torch.ones(t_map[key].shape).cuda())
         return loss
diff --git a/ptocr/utils/logger.py b/ptocr/utils/logger.py
index 4ca3794..16bc333 100644
--- a/ptocr/utils/logger.py
+++ b/ptocr/utils/logger.py
@@ -3,6 +3,34 @@
 # (C) Wei YANG 2017
 from __future__ import absolute_import
+import logging
+
+class TrainLog(object):
+    def __init__(self,LOG_FILE):
+        file_handler = logging.FileHandler(LOG_FILE)  # write records to the log file
+        console_handler = logging.StreamHandler()     # echo records to the console
+        file_handler.setLevel('INFO')                 # INFO and above go to the file
+        console_handler.setLevel('INFO')              # INFO and above go to the console
+
+        fmt = '%(asctime)s - %(funcName)s - %(lineno)s - %(levelname)s - %(message)s'
+        formatter = logging.Formatter(fmt)
+        file_handler.setFormatter(formatter)          # apply the record format
+        console_handler.setFormatter(formatter)
+
+        logger = logging.getLogger('TrainLog')
+        logger.setLevel('INFO')                       # the logger level gates both handlers
+
+        logger.addHandler(console_handler)
+        logger.addHandler(file_handler)
+        self.logger = logger
+
+    def error(self,char):
+        self.logger.error(char)
+    def debug(self,char):
+        self.logger.debug(char)
+    def info(self,char):
+        self.logger.info(char)
+
 class Logger(object):
     def __init__(self, fpath, title=None, resume=False):
         self.file = None
diff --git a/ptocr/utils/transform_label.py b/ptocr/utils/transform_label.py
index a0d322e..27f4e11 100644
--- a/ptocr/utils/transform_label.py
+++ b/ptocr/utils/transform_label.py
@@ -11,6 +11,7 @@
 # import chardet
 import numpy as np
 import sys
+import abc
 
 def get_keys(key_path):
     with open(key_path,'r',encoding='utf-8') as fid:
@@ -18,7 +19,7 @@ def get_keys(key_path):
         lines = lines.strip('\n')
         return lines
 
 class strLabelConverter(object):
     """Convert between str and label.
@@ -125,4 +126,73 @@ def val(self):
         res = 0
         if self.n_count != 0:
             res = self.sum / float(self.n_count)
-        return res
\ No newline at end of file
+        return res
+
+
+class BaseConverter(object):
+
+    def __init__(self, character):
+        self.character = list(character)
+        self.dict = {}
+        for i, char in enumerate(self.character):
+            self.dict[char] = i
+
+    @abc.abstractmethod
+    def train_encode(self, *args, **kwargs):
+        '''encode text in train phase'''
+
+    @abc.abstractmethod
+    def test_encode(self, *args, **kwargs):
+        '''encode text in test phase'''
+
+    @abc.abstractmethod
+    def decode(self, *args, **kwargs):
+        '''decode label to text in train and test phase'''
+
+
+class FCConverter(BaseConverter):
+
+    def __init__(self, config):
+        batch_max_length = config['base']['max_length']
+        character = get_keys(config['trainload']['key_file'])
+        self.character = character
+        list_token = ['[s]']
+        ignore_token = ['[ignore]']
+        list_character = list(character)
+        self.batch_max_length = batch_max_length + 1
+        super(FCConverter, self).__init__(character=list_token + list_character + ignore_token)
+        self.ignore_index = self.dict[ignore_token[0]]
+
+    def encode(self, text):
+        length = [len(s) + 1 for s in text]  # +1 for [s] at end of sentence.
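+        # Every label row is padded out to batch_max_length with the
+        # [ignore] index; an '[s]' end-of-sequence token is appended to each
+        # text below, and labels that are still too long are truncated.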
+        batch_text = torch.LongTensor(len(text), self.batch_max_length).fill_(self.ignore_index)
+        for i, t in enumerate(text):
+            text = list(t)
+            text.append('[s]')
+            text = [self.dict[char] for char in text]
+            if self.batch_max_length >= len(text):
+                batch_text[i][:len(text)] = torch.LongTensor(text)
+            else:
+                batch_text[i][:self.batch_max_length] = torch.LongTensor(text)[:self.batch_max_length]
+        batch_text_input = batch_text
+        batch_text_target = batch_text
+
+        return batch_text_input, torch.IntTensor(length), batch_text_target
+
+    def train_encode(self, text):
+        return self.encode(text)
+
+    def test_encode(self, text):
+        return self.encode(text)
+
+    def decode(self, text_index):
+        texts = []
+        batch_size = text_index.shape[0]
+        for index in range(batch_size):
+            text = ''.join([self.character[i] for i in text_index[index, :]])
+            text = text[:text.find('[s]')]
+            texts.append(text)
+
+        return texts
\ No newline at end of file
diff --git a/ptocr/utils/util_function.py b/ptocr/utils/util_function.py
index c2b3dba..5c8a97b 100644
--- a/ptocr/utils/util_function.py
+++ b/ptocr/utils/util_function.py
@@ -10,7 +10,20 @@
 import cv2
 import torch
 import numpy as np
-
+from PIL import Image
+
+def PILImageToCV(img,is_gray=False):
+    img = np.asarray(img)
+    if not is_gray:
+        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+    return img
+
+def CVImageToPIL(img,is_gray=False):
+    if not is_gray:
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    img = Image.fromarray(img)
+    return img
+
 def create_module(module_str):
     tmpss = module_str.split(",")
     assert len(tmpss) == 2, "Error formate\
@@ -20,6 +33,22 @@ def create_module(module_str):
     function = getattr(somemodule, function_name)
     return function
 
+
+def restore_training(model_file,model,optimizer):
+    checkpoint = torch.load(model_file)
+    start_epoch = checkpoint['epoch']
+    try:
+        model.load_state_dict(checkpoint['state_dict'])
+    except:
+        # fall back for checkpoints saved without the DataParallel 'module.' prefix
+        state = model.state_dict()
+        for key in state.keys():
+            state[key] = checkpoint['state_dict'][key[7:]]
+        model.load_state_dict(state)
+    optimizer.load_state_dict(checkpoint['optimizer'])
+    best_acc = checkpoint['best_acc']
+    return model,optimizer,start_epoch,best_acc
+
+
 def resize_image_batch(img,algorithm,side_len=1536,add_padding=True):
     if(algorithm=='SAST'):
@@ -71,7 +100,21 @@ def resize_image(img,algorithm,side_len=736,stride = 128):
     resized_img = cv2.resize(img, (new_width, new_height))
     return resized_img
 
+def resize_image_crnn(img,max_width=100,side_len=32,stride=4):
+    height, width, _ = img.shape
+    new_height = side_len
+    # scale the width with the aspect ratio, rounded up to a multiple of stride
+    new_width = int(math.ceil(new_height / height * width / stride) * stride)
+    if new_width > max_width:
+        resized_img = cv2.resize(img, (max_width, new_height))
+    else:
+        resized_img = cv2.resize(img, (new_width, new_height))
+        resized_img = cv2.copyMakeBorder(resized_img, 0, 0,
+                0, max_width-new_width, cv2.BORDER_CONSTANT, value=(0,0,0))
+    return resized_img
+
+
 def save_checkpoint(state, checkpoint='checkpoint', filename='model.pth.tar'):
     filepath = os.path.join(checkpoint, filename)
     torch.save(state, filepath)
@@ -89,8 +132,13 @@ def loss_mean(self):
     def loss_clear(self):
         self.loss_items = []
 
-def create_process_obj(algorithm,pred):
+def create_process_obj(config,pred):
+    algorithm = config['base']['algorithm']
+    if(algorithm=='DB'):
+        if 'n_class' in config['base'].keys():
+            pred,pred_class = pred
+            return pred.cpu().numpy(),pred_class
         return pred.cpu().numpy()
     elif(algorithm=='SAST'):
         pred['f_score'] = pred['f_score'].cpu().numpy()
@@ -102,10 +150,15 @@
     return pred
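An aside on the multi-class DB path above: when `n_class` is configured, `create_process_obj` hands back the raw per-pixel class logits alongside the shrink map, and the class map is only derived downstream. A minimal sketch of that softmax/argmax step, with dummy tensors whose shapes are assumed for illustration, mirroring what `tools/det_infer.py` does:

```python
import torch
import torch.nn.functional as F

# dummy outputs standing in for a DB model with an extra class branch
n_class = 4
pred = torch.rand(1, 1, 640, 640)              # shrink/probability map
pred_class = torch.rand(1, n_class, 640, 640)  # per-pixel class logits

b, _, h, w = pred_class.shape
logits = pred_class.permute(0, 2, 3, 1).reshape(-1, n_class)
probs = F.softmax(logits, dim=-1)
# most likely class per pixel, reshaped back to (B, 1, H, W)
class_map = probs.max(1)[1].reshape(b, h, w).unsqueeze(1)
print(class_map.shape)  # torch.Size([1, 1, 640, 640])
```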
-def create_loss_bin(algorithm,use_distil=False,use_center=False): +def create_loss_bin(config,use_distil=False,use_center=False): + algorithm = config['base']['algorithm'] bin_dict = {} if(algorithm=='DB'): - keys = ['loss_total','loss_l1', 'loss_bce', 'loss_thresh'] + if 'n_class' in config['base'].keys(): + keys = ['loss_total','loss_l1', 'loss_bce','loss_class', 'loss_thresh'] + else: + keys = ['loss_total','loss_l1', 'loss_bce', 'loss_thresh'] + elif(algorithm=='PSE'): keys = ['loss_total','loss_kernel', 'loss_text'] elif(algorithm=='PAN'): @@ -117,6 +170,8 @@ def create_loss_bin(algorithm,use_distil=False,use_center=False): keys = ['loss_total','loss_ctc','loss_center'] else: keys = ['loss_ctc'] + elif (algorithm == 'FC'): + keys = ['loss_fc'] else: assert 1==2,'only support algorithm DB,SAST,PSE,PAN,CRNN !!!' @@ -144,14 +199,14 @@ def load_model(model,model_path): if torch.cuda.is_available(): model_dict = torch.load(model_path) else: - model_dict = torch.load(model_path,map_location='cpu') - + model_dict = torch.load(model_path,map_location='cpu') if('state_dict' in model_dict.keys()): model_dict = model_dict['state_dict'] try: model.load_state_dict(model_dict) except: + state = model.state_dict() for key in state.keys(): state[key] = model_dict['module.' + key] @@ -167,6 +222,21 @@ def merge_config(config,args): return config +def FreezeParameters(model,layer_name): + for name, parameter in model.named_parameters(): + if layer_name in name: + parameter.requires_grad = False + return filter(lambda p: p.requires_grad, model.parameters()) + +def ReleaseParameters(model,optimizer,layer_name): + for name, parameter in model.named_parameters(): + parameter_dict = {} + if layer_name in name: + parameter.requires_grad = True + parameter_dict['params'] = parameter + optimizer.add_param_group(parameter_dict) + return optimizer + class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): diff --git a/script/__pycache__/__init__.cpython-36.pyc b/script/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000..04478db Binary files /dev/null and b/script/__pycache__/__init__.cpython-36.pyc differ diff --git a/script/__pycache__/onnx_to_tensorrt.cpython-36.pyc b/script/__pycache__/onnx_to_tensorrt.cpython-36.pyc new file mode 100644 index 0000000..d7121cb Binary files /dev/null and b/script/__pycache__/onnx_to_tensorrt.cpython-36.pyc differ diff --git a/script/create_lmdb.py b/script/create_lmdb.py index 51833ee..64e13eb 100644 --- a/script/create_lmdb.py +++ b/script/create_lmdb.py @@ -5,6 +5,7 @@ import argparse import shutil import sys +from tqdm import tqdm def checkImageIsValid(imageBin): if imageBin is None: @@ -54,7 +55,9 @@ def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkV env = lmdb.open(outputPath, map_size=1099511627776) cache = {} cnt = 1 + bar = tqdm(total=nSamples) for i in range(nSamples): + bar.update(1) imagePath = imagePathList[i] label = labelList[i] @@ -78,50 +81,38 @@ def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkV if cnt % 1000 == 0: writeCache(env, cache) cache = {} - print('Written %d / %d' % (cnt, nSamples)) +# print('Written %d / %d' % (cnt, nSamples)) cnt += 1 + bar.close() nSamples = cnt-1 cache['num-samples'] = str(nSamples) writeCache(env, cache) env.close() print('Created dataset with %d samples' % nSamples) -def read_data_from_folder(folder_path): - image_path_list = [] - label_list = [] - pics = os.listdir(folder_path) - pics.sort(key = 
lambda i: len(i)) - for pic in pics: - image_path_list.append(folder_path + '/' + pic) - label_list.append(pic.split('_')[0]) - return image_path_list, label_list def read_data_from_file(file_path): image_path_list = [] label_list = [] - f = open(file_path) - while True: - line1 = f.readline() - line2 = f.readline() - if not line1 or not line2: - break - line1 = line1.replace('\r', '').replace('\n', '') - line2 = line2.replace('\r', '').replace('\n', '') - image_path_list.append(line1) - label_list.append(line2) + with open(file_path,'r',encoding='utf-8') as fid: + lines = fid.readlines() + for line in lines: + line = line.replace('\r', '').replace('\n', '').split('\t') + image_path_list.append(line[0]) + label_list.append(line[1]) return image_path_list, label_list def show_demo(demo_number, image_path_list, label_list): - print ('\nShow some demo to prevent creating wrong lmdb data') print ('The first line is the path to image and the second line is the image label') + print ('###########################################################################') for i in range(demo_number): print ('image: %s\nlabel: %s\n' % (image_path_list[i], label_list[i])) + print ('###########################################################################') if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--out', type = str, required = True, help = 'lmdb data output path') - parser.add_argument('--folder', type = str, help = 'path to folder which contains the images') parser.add_argument('--file', type = str, help = 'path to file which contains the image path and label') args = parser.parse_args() @@ -129,10 +120,7 @@ def show_demo(demo_number, image_path_list, label_list): image_path_list, label_list = read_data_from_file(args.file) createDataset(args.out, image_path_list, label_list) show_demo(2, image_path_list, label_list) - elif args.folder is not None: - image_path_list, label_list = read_data_from_folder(args.folder) - createDataset(args.out, image_path_list, label_list) - show_demo(2, image_path_list, label_list) + else: - print ('Please use --floder or --file to assign the input. Use -h to see more.') + print ('Please use --file to assign the input. 
Use -h to see more.') sys.exit() \ No newline at end of file diff --git a/script/create_lmdb_multiprocessing.py b/script/create_lmdb_multiprocessing.py new file mode 100644 index 0000000..3ed9b1d --- /dev/null +++ b/script/create_lmdb_multiprocessing.py @@ -0,0 +1,206 @@ +import os +import lmdb +import cv2 +import numpy as np +import argparse +import shutil +import sys +from tqdm import tqdm +import time +import six +from PIL import Image +from multiprocessing import Process + +def get_dict(char_list): + char_dict={} + for item in char_list: + if item in char_dict.keys(): + char_dict[item]+=1 + else: + char_dict[item]=1 + return char_dict + +def checklmdb(args): + env = lmdb.open( + args.out, + max_readers=2, + readonly=True, + lock=False, + readahead=False, + meminit=False) + + with env.begin(write=False) as txn: + nSamples = int(txn.get('num-samples'.encode('utf-8'))) + print('Check lmdb ok!!!') + print('lmdb Have {} samples'.format(nSamples)) + print('Print 5 samples:') + + for index in range(1,nSamples+1): + if index>5: + break + img_key = 'image-%09d' % index + imgbuf = txn.get(img_key.encode('utf-8')) + label_key = 'label-%09d' % index + label = txn.get(label_key.encode('utf-8')).decode() + print(index,label) + + +def checkImageIsValid(imageBin): + if imageBin is None: + return False + + try: + imageBuf = np.fromstring(imageBin, dtype=np.uint8) + img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE) + imgH, imgW = img.shape[0], img.shape[1] + except: + return False + else: + if imgH * imgW == 0: + return False + + return True + + +def writeCache(env, cache): + with env.begin(write=True) as txn: + for k, v in cache.items(): + if type(k) == str: + k = k.encode() + if type(v) == str: + v = v.encode() + txn.put(k,v) + + +def write(imagePathList,labelList,env,start,end): + cache = {} + cnt = 1 + bar = tqdm(total=end-start) + checkValid=True + lexiconList=None + for i in range(end-start): + bar.update(1) + imagePath = imagePathList[start+i] + label = labelList[start+i] + + if not os.path.exists(imagePath): + print('%s does not exist' % imagePath) + continue + with open(imagePath, 'rb') as f: + imageBin = f.read() + if checkValid: + if not checkImageIsValid(imageBin): + print('%s is not a valid image' % imagePath) + continue + + imageKey = 'image-%09d' % (start+cnt) + labelKey = 'label-%09d' % (start+cnt) + + cache[imageKey] = imageBin + cache[labelKey] = label + if lexiconList: + lexiconKey = 'lexicon-%09d' % cnt + cache[lexiconKey] = ' '.join(lexiconList[i]) + if (start+cnt) % 1000 == 0: + writeCache(env, cache) + cache = {} + cnt += 1 + bar.close() + writeCache(env, cache) + +def createDataset(outputPath, imagePathList, labelList, num=1,lexiconList=None, checkValid=True): + """ + Create LMDB dataset for CRNN training. + ARGS: + outputPath : LMDB output path + imagePathList : list of image path + labelList : list of corresponding groundtruth texts + lexiconList : (optional) list of lexicon lists + checkValid : if true, check the validity of every image + """ + # If lmdb file already exists, remove it. Or the new data will add to it. 
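+    # Note: python-lmdb cautions against sharing an Environment across
+    # fork(), and lmdb allows only one write transaction at a time, so the
+    # worker processes below effectively serialize on the writer lock; the
+    # speedup comes from reading and validating images in parallel.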
+ if os.path.exists(outputPath): + shutil.rmtree(outputPath) + os.makedirs(outputPath) + else: + os.makedirs(outputPath) + + assert (len(imagePathList) == len(labelList)) + nSamples = len(imagePathList) + env = lmdb.open(outputPath, map_size=1099511627776) + cache = {} + index = [] + + if nSamples%num==0: + step = nSamples//num + for i in range(num): + index.append([i*step,(i+1)*step]) + else: + step = nSamples//num + for i in range(num): + index.append([i*step,(i+1)*step]) + index.append([num*step,nSamples]) + + p_list = [] + if nSamples%num==0: + for i in range(num): + p = Process(target=write,args=(imagePathList,labelList,env,index[i][0],index[i][1])) + p_list.append(p) + p.start() + else: + for i in range(num+1): + p = Process(target=write,args=(imagePathList,labelList,env,index[i][0],index[i][1])) + p_list.append(p) + p.start() + for p in p_list: + p.join() + cache['num-samples'] = str(nSamples) + writeCache(env, cache) + env.close() + print('Created dataset with %d samples' % nSamples) + + +def read_data_from_file(file_path): + image_path_list = [] + label_list = [] + with open(file_path,'r',encoding='utf-8') as fid: + lines = fid.readlines() + for line in lines: + line = line.replace('\r', '').replace('\n', '').split('\t') + image_path_list.append(line[0]) + label_list.append(line[1]) + + return image_path_list, label_list + +def show_demo(demo_number, image_path_list, label_list): + print ('The first line is the path to image and the second line is the image label') + print ('###########################################################################') + for i in range(demo_number): + print ('image: %s\nlabel: %s\n' % (image_path_list[i], label_list[i])) + print ('###########################################################################') + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--out', type = str, required = True, help = 'lmdb data output path') + parser.add_argument('--file', type = str,required = True, help = 'path to file which contains the image path and label') + parser.add_argument('--num_process', type = int, required = True, help = 'num_process to do') + args = parser.parse_args() + + if args.file is not None: + image_path_list, label_list = read_data_from_file(args.file) + show_demo(2, image_path_list, label_list) + s_time = time.time() + createDataset(args.out, image_path_list, label_list,num=args.num_process) + print('cost_time:'+str(time.time()-s_time)+'s') + else: + print ('Please use --file to assign the input. 
Use -h to see more.') + sys.exit() + print('lmdb generate ok!!!!') + checklmdb(args) + + + + + + + \ No newline at end of file diff --git a/script/get_train_list.py b/script/get_train_list.py index 6e5872d..0b481b9 100644 --- a/script/get_train_list.py +++ b/script/get_train_list.py @@ -26,4 +26,5 @@ def gen_train_file(args): parser.add_argument('--label_path', nargs='?', type=str, default=None) parser.add_argument('--img_path', nargs='?', type=str, default=None) parser.add_argument('--save_path', nargs='?', type=str, default=None) - args = parser.parse_args() \ No newline at end of file + args = parser.parse_args() + gen_train_file(args) \ No newline at end of file diff --git a/script/warp_polar.py b/script/warp_polar.py new file mode 100644 index 0000000..00399a6 --- /dev/null +++ b/script/warp_polar.py @@ -0,0 +1,67 @@ +#-*- coding:utf-8 _*- +""" +@author:fxw +@file: tt.py +@time: 2020/12/25 +""" +import cv2 +import numpy as np +import sys + +#实现图像的极坐标的转换 center代表及坐标变换中心‘;r是一个二元元组,代表最大与最小的距离;theta代表角度范围 +#rstep代表步长; thetastap代表角度的变化步长 +def polar(image,center,r,theta=(70,360+70),rstep=0.8,thetastep=360.0/(360*2)): + #得到距离的最小值、最大值 + minr,maxr=r + #角度的最小范围 + mintheta,maxtheta=theta + #输出图像的高、宽 O:指定形状类型的数组float64 + H=int((maxr-minr)/rstep)+1 + W=int((maxtheta-mintheta)/thetastep)+1 + O=125*np.ones((H,W,3),image.dtype) + #极坐标转换 利用tile函数实现W*1铺成的r个矩阵 并对生成的矩阵进行转置 + r=np.linspace(minr,maxr,H) + r=np.tile(r,(W,1)) + r=np.transpose(r) + theta=np.linspace(mintheta,maxtheta,W) + theta=np.tile(theta,(H,1)) + x,y=cv2.polarToCart(r,theta,angleInDegrees=True) + #最近插值法 + for i in range(H): + for j in range(W): + px=int(round(x[i][j])+cx) + py=int(round(y[i][j])+cy) + if((px>=0 and px<=w-1) and (py>=0 and py<=h-1)): + O[i][j][0]=image[py][px][0] + O[i][j][1]=image[py][px][1] + O[i][j][2]=image[py][px][2] + + return O + +import time +if __name__=="__main__": + img = cv2.imread(r"C:\Users\fangxuwei\Desktop\111.jpg") + # 传入的图像宽:600 高:400 + h, w = img.shape[:2] + print("h:%s w:%s"%(h,w)) + # 极坐标的变换中心(300,200) + # cx, cy = h//2, w//2 + cx, cy = 204, 201 + # cx, cy = 131, 123 + # 圆的半径为10 颜色:灰 最小位数3 + cv2.circle(img, (int(cx), int(cy)), 10, (255, 0, 0, 0), 3) + s = time.time() + L = polar(img, (cx, cy), (h//3, w//2)) + # 旋转 + L = cv2.flip(L, 0) + print(time.time()-s) + + # 显示与输出 + cv2.imshow('img', img) + cv2.imshow('O', L) + cv2.waitKey(0) + + + + + diff --git a/tools/__pycache__/MarginLoss.cpython-36.pyc b/tools/__pycache__/MarginLoss.cpython-36.pyc new file mode 100644 index 0000000..34f1559 Binary files /dev/null and b/tools/__pycache__/MarginLoss.cpython-36.pyc differ diff --git a/tools/__pycache__/__init__.cpython-36.pyc b/tools/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000..f3c0dad Binary files /dev/null and b/tools/__pycache__/__init__.cpython-36.pyc differ diff --git a/tools/cal_rescall/__pycache__/__init__.cpython-36.pyc b/tools/cal_rescall/__pycache__/__init__.cpython-36.pyc index a36e979..22329de 100644 Binary files a/tools/cal_rescall/__pycache__/__init__.cpython-36.pyc and b/tools/cal_rescall/__pycache__/__init__.cpython-36.pyc differ diff --git a/tools/cal_rescall/__pycache__/cal_det.cpython-36.pyc b/tools/cal_rescall/__pycache__/cal_det.cpython-36.pyc index 9b788ab..1ba789f 100644 Binary files a/tools/cal_rescall/__pycache__/cal_det.cpython-36.pyc and b/tools/cal_rescall/__pycache__/cal_det.cpython-36.pyc differ diff --git a/tools/cal_rescall/__pycache__/cal_iou.cpython-36.pyc b/tools/cal_rescall/__pycache__/cal_iou.cpython-36.pyc index bca5b29..152a5c3 100644 Binary 
files a/tools/cal_rescall/__pycache__/cal_iou.cpython-36.pyc and b/tools/cal_rescall/__pycache__/cal_iou.cpython-36.pyc differ diff --git a/tools/cal_rescall/__pycache__/rrc_evaluation_funcs.cpython-36.pyc b/tools/cal_rescall/__pycache__/rrc_evaluation_funcs.cpython-36.pyc index 66f438a..e52d2ad 100644 Binary files a/tools/cal_rescall/__pycache__/rrc_evaluation_funcs.cpython-36.pyc and b/tools/cal_rescall/__pycache__/rrc_evaluation_funcs.cpython-36.pyc differ diff --git a/tools/cal_rescall/__pycache__/script.cpython-36.pyc b/tools/cal_rescall/__pycache__/script.cpython-36.pyc index 6eebef9..8a8f59f 100644 Binary files a/tools/cal_rescall/__pycache__/script.cpython-36.pyc and b/tools/cal_rescall/__pycache__/script.cpython-36.pyc differ diff --git a/tools/det_infer.py b/tools/det_infer.py index bc9348e..94986d4 100644 --- a/tools/det_infer.py +++ b/tools/det_infer.py @@ -16,6 +16,7 @@ import numpy as np from tqdm import tqdm import onnxruntime +import torch.nn.functional as F import torchvision.transforms as transforms from ptocr.utils.util_function import create_module,resize_image_batch from ptocr.utils.util_function import create_process_obj,create_dir,load_model @@ -115,19 +116,28 @@ def infer_img(self,ori_imgs): out = output[0].reshape(int(args.batch_size),7,test_size,test_size) out = torch.Tensor(out) - -# import pdb -# pdb.set_trace() - + if isinstance(out,dict): img_num = out['f_score'].shape[0] + elif isinstance(out,tuple): + img_num = out[0].shape[0] else: img_num = out.shape[0] if(self.config['base']['algorithm']=='SAST'): - scales = [(scale[0],scale[1],ori_imgs[i].shape[0],ori_imgs[i].shape[1]) for scale in scales] + scales = [(1.0/scales[i][1],1.0/scales[i][0],ori_imgs[i].shape[0],ori_imgs[i].shape[1]) for i in range(len(scales))] - out = create_process_obj(self.config['base']['algorithm'],out) + out = create_process_obj(self.config,out) + + if 'n_class' in self.config['base'].keys(): + out,out_class = out + b,_,w,h = out_class.shape + out_class = out_class.permute(0,2,3,1).reshape(-1,self.config['base']['n_class']) + out_class = F.softmax(out_class,-1) + out_class = out_class.max(1)[1].reshape(b,w,h).unsqueeze(1) + bbox_batch, score_batch,class_batch = self.img_process(out,out_class.cpu().numpy(), scales) + return bbox_batch,score_batch,class_batch + bbox_batch, score_batch = self.img_process(out, scales) return bbox_batch,score_batch @@ -145,6 +155,28 @@ def InferOneImg(bin,img,image_name,save_path): fid_res.write(bbox_str) cv2.imwrite(os.path.join(save_path, 'result_img', image_name[i] + '.jpg'), img_show) +color_rgb = [(0,0,255),(0,255,0),(255,0,0),(255,255,0),(255,0,255),(0,255,255),(156,102,31),(255,192,203),(160,32,240),(115,74,18)] + +def InferOneImgMul(bin,img,image_name,save_path,n_class): + bbox_batch, score_batch,class_batch = bin.infer_img(img) + colors = {} + for i in range(1,n_class+1): + colors[str(i-1)] = color_rgb[i-1] + for i in range(len(bbox_batch)): + img_show = img[i].copy() + with open(os.path.join(save_path, 'result_txt', 'res_' + image_name[i] + '.txt'), 'w+', encoding='utf-8') as fid_res: + bboxes = bbox_batch[i] + classes = class_batch[i] + tag = 0 + for bbox in bboxes: + bbox = bbox.reshape(-1, 2).astype(np.int) + img_show = cv2.drawContours(img_show, [bbox], -1, colors[str(int(classes[tag]))], 2) + bbox_str = [str(x) for x in bbox.reshape(-1)] + bbox_str = ','.join(bbox_str) + '\n' + fid_res.write(bbox_str) + tag+=1 + cv2.imwrite(os.path.join(save_path, 'result_img', image_name[i] + '.jpg'), img_show) + def InferImage(config): path = 
config['infer']['path'] save_path = config['infer']['save_path'] @@ -159,7 +191,10 @@ def InferImage(config): bar = tqdm(total=len(batch_imgs)) for i in range(len(batch_imgs)): bar.update(1) - InferOneImg(test_bin, batch_imgs[i],batch_img_names[i], save_path) + if 'n_class' in config['base'].keys(): + InferOneImgMul(test_bin, batch_imgs[i],batch_img_names[i], save_path,config['base']['n_class']) + else: + InferOneImg(test_bin, batch_imgs[i],batch_img_names[i], save_path) bar.close() else: diff --git a/tools/det_sast.py b/tools/det_sast.py index 82087a4..7bd8eb1 100644 --- a/tools/det_sast.py +++ b/tools/det_sast.py @@ -192,7 +192,7 @@ def TrainValProgram(config): log_write.set_names(title) for epoch in range(start_epoch,config['base']['n_epoch']): - train_dataset.gen_train_img() +# train_dataset.gen_train_img() model.train() optimizer_decay(config, optimizer, epoch) loss_write = ModelTrain(train_data_loader, model, criterion, optimizer, loss_bin, config, epoch) @@ -231,8 +231,8 @@ def TrainValProgram(config): if __name__ == "__main__": - stream = open('./config/det_SAST_resnet50_ori_dataload.yaml', 'r', encoding='utf-8') -# stream = open('./config/det_SAST_resnet50_3*3_ori_dataload.yaml', 'r', encoding='utf-8') +# stream = open('./config/det_SAST_resnet50_ori_dataload.yaml', 'r', encoding='utf-8') + stream = open('./config/det_SAST_resnet50_3_3_ori_dataload.yaml', 'r', encoding='utf-8') config = yaml.load(stream,Loader=yaml.FullLoader) TrainValProgram(config) \ No newline at end of file diff --git a/tools/det_train.py b/tools/det_train.py index d22ad67..d8ee729 100644 --- a/tools/det_train.py +++ b/tools/det_train.py @@ -165,7 +165,7 @@ def TrainValProgram(args): criterion = create_module(config['architectures']['loss_function'])(config) train_dataset = create_module(config['trainload']['function'])(config) test_dataset = create_module(config['testload']['function'])(config) - optimizer = create_module(config['optimizer']['function'])(config, model) + optimizer = create_module(config['optimizer']['function'])(config, model.parameters()) optimizer_decay = create_module(config['optimizer_decay']['function']) img_process = create_module(config['postprocess']['function'])(config) @@ -195,7 +195,7 @@ def TrainValProgram(args): use_distil = False if args.t_config is not None: use_distil = True - loss_bin = create_loss_bin(config['base']['algorithm'],use_distil) + loss_bin = create_loss_bin(config,use_distil) if torch.cuda.is_available(): if (len(config['base']['gpu_id'].split(',')) > 1): diff --git a/tools/rec_infer.py b/tools/rec_infer.py index c5cecac..31a505e 100644 --- a/tools/rec_infer.py +++ b/tools/rec_infer.py @@ -5,6 +5,7 @@ @time: 2020/08/20 """ import os +os.environ["CUDA_VISIBLE_DEVICES"] = '2' import sys sys.path.append('./') import cv2 @@ -14,8 +15,24 @@ import numpy as np from tqdm import tqdm import torchvision.transforms as transforms -from ptocr.utils.util_function import create_module,resize_image +from ptocr.utils.util_function import create_module,resize_image,resize_image_crnn from ptocr.utils.util_function import create_process_obj,create_dir,load_model +from get_acc_english import get_test_acc + +def is_chinese(check_str): + flag = False + for ch in check_str.decode('utf-8'): + if u'\u4e00' >= ch or ch >= u'\u9fff': + flag = True + return flag + +def is_number_letter(check_str): + if(check_str in '0123456789' or check_str in 'abcdefghijklmnopgrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'): + return True + else: + return False + + class TestProgram(): @@ -26,6 +43,7 @@ def 
__init__(self,config): config['base']['classes'] = len(self.converter.alphabet) model = create_module(config['architectures']['model_function'])(config) model = load_model(model,config['infer']['model_path']) + if torch.cuda.is_available(): model = model.cuda() self.model = model @@ -33,39 +51,96 @@ def __init__(self,config): self.model.eval() def infer_img(self,ori_img): - img = resize_image(ori_img,self.congig['base']['algorithm'],32) + +# img = resize_image(ori_img,self.congig['base']['algorithm'],32,4) + img = resize_image_crnn(ori_img) +# cv2.imwrite('result.jpg',img) + if(img.shape[0]!=self.congig['base']['img_shape'][0]): + return '' + img = Image.fromarray(img).convert('RGB') if(self.congig['base']['is_gray']): img = img.convert('L') +# img = img.resize((100, 32),Image.ANTIALIAS) + image_for_show = np.array(img.convert('RGB')).copy() img = transforms.ToTensor()(img) img.sub_(0.5).div_(0.5) img = img.unsqueeze(0) + if torch.cuda.is_available(): img = img.cuda() with torch.no_grad(): - preds = self.model(img) + preds,feau= self.model(img) + preds = preds.permute(1, 0, 2) + + time_step = preds.shape[0] + image_width = img.shape[3] + step_ = image_width//time_step + + ### 输出相似字结果 +# k_num = 3 +# p,idx = torch.softmax(preds,-1).squeeze().topk(k_num,-1) +# p = p[idx[:,0]>0].cpu().numpy() +# for i in range(k_num): +# index = idx[:,i][idx[:,0]>0] +# result = np.array(list(self.converter.alphabet))[index.cpu().numpy()-1].tolist() +# if(i==0): +# print('识别结果:',result,p[:,i]) +# else: +# print('相似字:',result,p[:,i]) + preds_size = torch.IntTensor([preds.size(0)]) _, preds = preds.max(2) preds = preds.squeeze(1) preds = preds.contiguous().view(-1) + sim_preds = self.converter.decode(preds.data, preds_size.data, raw=False) +# raw_pred = self.converter.decode(preds.data, preds_size.data, raw=True) +# start_step = 0 +# word_po = [] +# for i in range(len(raw_pred)-1): +# if(raw_pred[i]!='-' and raw_pred[i]!=raw_pred[i+1]): +# # image_for_show = cv2.rectangle(image_for_show,(start_step+step_,0),(start_step+step_, img.shape[2]-1),(0,0,255)) +# start_step+=step_ +# word_po.append(start_step+step_) +# else: +# start_step+=step_ + +# # word_po = np.array(word_po) +# # word_po_flag = np.zeros(len(word_po)) +# # word_po_flag[1:] = word_po[:-1] +# # word_len = sorted(word_po-word_po_flag)[len(word_po)//2] +# word_s_e = [] +# for i in range(len(word_po)): +# if(is_number_letter(sim_preds[i])): +# word_len = self.congig['base']['img_shape'][0]//4 +# else: +# word_len = self.congig['base']['img_shape'][0]//2 +# word_s_e.append([word_po[i]-word_len,word_po[i]+word_len]) +# # image_for_show = cv2.rectangle(image_for_show,(int(word_po[i]-word_len),1),(int(word_po[i]+word_len), img.shape[2]-1),(255,0,255)) +# image_for_show = cv2.line(image_for_show,(int(word_po[i]),1),(int(word_po[i]), img.shape[2]-1),(255,0,255)) +# cv2.imwrite('result.jpg',image_for_show) + return sim_preds def InferImage(config): path = config['infer']['path'] save_path = config['infer']['save_path'] test_bin = TestProgram(config) + fid = open('re_result.txt','w+',encoding='utf-8') if os.path.isdir(path): files = os.listdir(path) bar = tqdm(total=len(files)) for file in files: + print(file) bar.update(1) image_name = file.split('.')[0] img_path = os.path.join(path,file) img = cv2.imread(img_path) rec_char = test_bin.infer_img(img) + fid.write(file+'\t'+rec_char+'\n') print(rec_char) bar.close() @@ -74,9 +149,51 @@ def InferImage(config): img = cv2.imread(path) rec_char = test_bin.infer_img(img) print(rec_char) + +def InferImageAcc(config): + 
root_path = config['infer']['path']
+    save_path = config['infer']['save_path']
+    test_bin = TestProgram(config)
+    acc_fid = open('acc.txt','w+',encoding='utf-8')
+    acc_all = 0
+    for _dir in os.listdir(root_path):
+        path = os.path.join(root_path,_dir,'image')
+        gt_file = os.path.join(root_path,_dir,'val.txt')
+        fid = open(_dir+'.txt','w+',encoding='utf-8')
+        if os.path.isdir(path):
+            files = os.listdir(path)
+            bar = tqdm(total=len(files))
+            for file in files:
+                bar.update(1)
+                image_name = file.split('.')[0]
+                img_path = os.path.join(path,file)
+                img = cv2.imread(img_path)
+                rec_char = test_bin.infer_img(img)
+                fid.write(file+'\t'+rec_char+'\n')
+            bar.close()
+        else:
+            image_name = path.split('/')[-1].split('.')[0]
+            img = cv2.imread(path)
+            rec_char = test_bin.infer_img(img)
+            print(rec_char)
+        fid.close()
+        acc = get_test_acc(_dir+'.txt',gt_file)
+        acc_all += acc
+        acc_fid.write(_dir+':'+'\t'+str(acc)+'\n')
+    print('mean acc:',acc_all/9.0)  # mean over the 9 evaluation sets
+    acc_fid.close()
+
 if __name__ == "__main__":
-    stream = open('./config/rec_CRNN_ori.yaml', 'r', encoding='utf-8')
+
+    stream = open('./config/rec_CRNN_mobilev3_small_english_all.yaml', 'r', encoding='utf-8')
     config = yaml.load(stream,Loader=yaml.FullLoader)
-    InferImage(config)
\ No newline at end of file
+    InferImageAcc(config)
+
+# stream = open('./config/rec_CRNN_resnet_english.yaml', 'r', encoding='utf-8')
+# config = yaml.load(stream,Loader=yaml.FullLoader)
+# InferImage(config)
\ No newline at end of file
diff --git a/tools/rec_infer_bk1.py b/tools/rec_infer_bk1.py
new file mode 100644
index 0000000..b1cd2a7
--- /dev/null
+++ b/tools/rec_infer_bk1.py
@@ -0,0 +1,159 @@
+#-*- coding:utf-8 _*-
+"""
+@author:fxw
+@file: rec_infer_bk1.py
+@time: 2020/08/20
+"""
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = '2'
+import sys
+sys.path.append('./')
+import cv2
+import torch
+import yaml
+from PIL import Image
+import numpy as np
+from tqdm import tqdm
+import torchvision.transforms as transforms
+from ptocr.utils.util_function import create_module,resize_image,resize_image_crnn
+from ptocr.utils.util_function import create_process_obj,create_dir,load_model
+from get_acc_english import get_test_acc
+
+def is_chinese(check_str):
+    # True if the string contains at least one CJK unified ideograph
+    for ch in check_str:
+        if u'\u4e00' <= ch <= u'\u9fff':
+            return True
+    return False
+
+def is_number_letter(check_str):
+    if(check_str in '0123456789' or check_str in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'):
+        return True
+    else:
+        return False
+
+
+class TestProgram():
+    def __init__(self,config):
+        super(TestProgram,self).__init__()
+
+        self.converter = create_module(config['label_transform']['function'])(config)
+        model = create_module(config['architectures']['model_function'])(config)
+        model = load_model(model,config['infer']['model_path'])
+
+        if torch.cuda.is_available():
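+            # run on the GPU when one is available; the model returns
+            # (logits, features), and the greedy argmax over the class
+            # dimension is decoded by the label converter further down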
+ img = img.cuda() + + with torch.no_grad(): + preds,feau= self.model(img) + + time_step = preds.shape[0] + image_width = img.shape[3] + step_ = image_width//time_step + + + _, preds = preds.max(2) + + + sim_preds = self.converter.decode(preds.data) + print(preds) + print(sim_preds) + return sim_preds[0] + +def InferImage(config): + path = config['infer']['path'] + save_path = config['infer']['save_path'] + test_bin = TestProgram(config) + fid = open('re_result.txt','w+',encoding='utf-8') + if os.path.isdir(path): + files = os.listdir(path) + bar = tqdm(total=len(files)) + for file in files: + print(file) + bar.update(1) + image_name = file.split('.')[0] + img_path = os.path.join(path,file) + img = cv2.imread(img_path) + rec_char = test_bin.infer_img(img) + fid.write(file+'\t'+rec_char+'\n') + print(rec_char) + bar.close() + + else: + image_name = path.split('/')[-1].split('.')[0] + img = cv2.imread(path) + rec_char = test_bin.infer_img(img) + print(rec_char) + +def InferImageAcc(config): + root_path = config['infer']['path'] + save_path = config['infer']['save_path'] + test_bin = TestProgram(config) + acc_fid = open('acc.txt','w+',encoding='utf-8') + acc_all = 0 + for _dir in os.listdir(root_path): + path = os.path.join(root_path,_dir,'image') + gt_file = os.path.join(root_path,_dir,'val.txt') +# print(gt_file) + fid = open(_dir+'.txt','w+',encoding='utf-8') + if os.path.isdir(path): + files = os.listdir(path) + bar = tqdm(total=len(files)) + for file in files: +# print(file) + bar.update(1) + image_name = file.split('.')[0] + img_path = os.path.join(path,file) + img = cv2.imread(img_path) + rec_char = test_bin.infer_img(img) + fid.write(file+'\t'+rec_char+'\n') +# print(rec_char) + bar.close() + + else: + image_name = path.split('/')[-1].split('.')[0] + img = cv2.imread(path) + rec_char = test_bin.infer_img(img) + print(rec_char) + fid.close() + acc = get_test_acc(_dir+'.txt',gt_file) + acc_all+=acc + acc_fid.write(_dir+':'+'\t'+str(acc)+'\n') + print('mean acc:',acc_all/9.0) + acc_fid.close() + + +if __name__ == "__main__": + + stream = open('./config/rec_FC_resnet_english_all.yaml', 'r', encoding='utf-8') + config = yaml.load(stream,Loader=yaml.FullLoader) + InferImageAcc(config) + +# stream = open('./config/rec_CRNN_resnet_english.yaml', 'r', encoding='utf-8') +# config = yaml.load(stream,Loader=yaml.FullLoader) +# InferImage(config) \ No newline at end of file diff --git a/tools/rec_train.py b/tools/rec_train.py index 73c7c41..b11ea09 100644 --- a/tools/rec_train.py +++ b/tools/rec_train.py @@ -1,9 +1,3 @@ -# -*- coding:utf-8 _*- -""" -@author:fxw -@file: det_train.py -@time: 2020/08/07 -""" import sys sys.path.append('./') import cv2 @@ -21,173 +15,114 @@ from ptocr.utils.util_function import create_module, create_loss_bin, \ set_seed,save_checkpoint,create_dir from ptocr.utils.metrics import runningScore -from ptocr.utils.logger import Logger -from ptocr.utils.util_function import create_process_obj,merge_config,AverageMeter +from ptocr.utils.logger import Logger,TrainLog +from ptocr.utils.util_function import create_process_obj,merge_config,AverageMeter,restore_training from ptocr.dataloader.RecLoad.CRNNProcess import alignCollate import copy -GLOBAL_WORKER_ID = None +### 设置随机种子 GLOBAL_SEED = 2020 - - torch.manual_seed(GLOBAL_SEED) torch.cuda.manual_seed(GLOBAL_SEED) torch.cuda.manual_seed_all(GLOBAL_SEED) np.random.seed(GLOBAL_SEED) random.seed(GLOBAL_SEED) - - -def worker_init_fn(worker_id): - global GLOBAL_WORKER_ID - GLOBAL_WORKER_ID = worker_id - set_seed(GLOBAL_SEED + 
worker_id) - -# def backward_hook(self,grad_input, grad_output): -# for g in grad_input: -# g[g != g] = 0 # replace all nan/inf in gradients to zero - - -def ModelTrain(train_data_loader,LabelConverter,model,center_model, criterion, optimizer,center_criterion,optimizer_center,center_flag,loss_bin, config, epoch): - batch_time = AverageMeter() - end = time.time() - - for batch_idx, data in enumerate(train_data_loader): -# model.register_backward_hook(backward_hook) - if(data is None): - continue - imgs,labels = data +def ModelTrain(train_data_loader,LabelConverter, model,criterion,optimizer, train_log,loss_dict, config, epoch): + for batch_idx, ( imgs,labels) in enumerate(train_data_loader): pre_batch = {} gt_batch = {} if torch.cuda.is_available(): imgs = imgs.cuda() - preds = model(imgs) - + preds,feau = model(imgs) + preds = preds.permute(1, 0, 2) + + ######### labels,labels_len = LabelConverter.encode(labels,preds.size(0)) preds_size = Variable(torch.IntTensor([preds.size(0)] * config['trainload']['batch_size'])) pre_batch['preds'],pre_batch['preds_size'] = preds,preds_size gt_batch['labels'],gt_batch['labels_len'] = labels,labels_len + ######### + + if config['loss']['use_ctc_weight']: + len_index = torch.softmax(preds,-1).max(2)[1].transpose(0,1)>0 + len_flag = torch.cat([labels_len.cuda().long().unsqueeze(0),len_index.sum(1).unsqueeze(0)],0) + ctc_loss_weight = len_flag.max(0)[0].float()/len_flag.min(0)[0].float() + ctc_loss_weight[ctc_loss_weight==torch.tensor(np.inf).cuda()]=2.0 + gt_batch['ctc_loss_weight'] = ctc_loss_weight + + loss = criterion(pre_batch, gt_batch).cuda() + optimizer.zero_grad() + loss.backward() + optimizer.step() - ctc_loss = criterion(pre_batch, gt_batch).cuda() metrics = {} - metrics['loss_total'] = 0.0 - metrics['loss_center'] = 0.0 - if center_criterion is not None and center_flag is True: - center_model.eval() - ##### - feautures = preds.clone() - with torch.no_grad(): - center_preds = center_model(imgs) - - center_preds = torch.softmax(center_preds,-1) - confs, center_preds = center_preds.max(2) - center_preds = center_preds.squeeze(1).transpose(1, 0).contiguous() - confs = confs.transpose(1, 0).contiguous() - -# confs = [] -# for i in range(center_preds.shape[0]): -# conf = [] -# for j in range(len(center_preds[i])): -# conf.append(probs[i,j,center_preds[i][j]]) -# confs.append(conf) -# confs = torch.Tensor(confs).cuda() - - b,t = center_preds.shape - feautures = feautures.transpose(1, 0).contiguous() - - confs = confs.view(-1) - center_preds = center_preds.view(-1) - feautures = feautures.view(b*t,-1) - - - index = (center_preds>0) & (confs>config['loss']['label_score']) - center_preds = center_preds[index] - feautures = feautures[index] - - center_loss = center_criterion(feautures,center_preds)*config['loss']['weight_center'] - - loss = ctc_loss + center_loss - - - metrics['loss_total'] = loss.item() - metrics['loss_ctc'] = ctc_loss.item() - metrics['loss_center'] = center_loss.item() - - ##### - optimizer_center.zero_grad() - optimizer.zero_grad() - - loss.backward() - optimizer.step() - - for param in center_criterion.parameters(): - param.grad.data *= (1. 
/ config['loss']['weight_center']) - optimizer_center.step() - else: - loss = ctc_loss - metrics['loss_ctc'] = ctc_loss.item() - optimizer.zero_grad() - loss.backward() - optimizer.step() + metrics['ctc_loss'] = loss.item() + + for key in loss_dict.keys(): + loss_dict[key].update(metrics[key]) - for key in loss_bin.keys(): - loss_bin[key].loss_add(metrics[key]) - batch_time.update(time.time() - end) - end = time.time() if (batch_idx % config['base']['show_step'] == 0): log = '({}/{}/{}/{}) | ' \ .format(epoch, config['base']['n_epoch'], batch_idx, len(train_data_loader)) - bin_keys = list(loss_bin.keys()) - - for i in range(len(bin_keys)): - log += bin_keys[i] + ':{:.4f}'.format(loss_bin[bin_keys[i]].loss_mean()) + ' | ' - log += 'lr:{:.8f}'.format(optimizer.param_groups[0]['lr'])+ ' | ' - log+='batch_time:{:.2f} s'.format(batch_time.avg)+ ' | ' - log+='total_time:{:.2f} min'.format(batch_time.avg * batch_idx / 60.0)+ ' | ' - log+='ETA:{:.2f} min'.format(batch_time.avg*(len(train_data_loader)-batch_idx)/60.0) - print(log) - loss_write = [] - for key in list(loss_bin.keys()): - loss_write.append(loss_bin[key].loss_mean()) - return loss_write,loss_bin['loss_ctc'].loss_mean() - + keys = list(loss_dict.keys()) + for i in range(len(keys)): + log += keys[i] + ':{:.4f}'.format(loss_dict[keys[i]].avg) + ' | ' + log += 'lr:{:.8f}'.format(optimizer.param_groups[0]['lr']) + train_log.info(log) -def ModelEval(test_data_loader,LabelConverter,model,criterion,config): - bar = tqdm(total=len(test_data_loader)) - loss_avg = [] - n_correct = 0 - for batch_idx, (imgs, labels) in enumerate(test_data_loader): +def ModelEval(val_data_loader,LabelConverter, model,criterion,train_log,loss_dict,config): + + bar = tqdm(total=len(val_data_loader)) + val_loss = AverageMeter() + n_correct = AverageMeter() + + for batch_idx, (imgs, labels) in enumerate(val_data_loader): bar.update(1) + pre_batch = {} gt_batch = {} + if torch.cuda.is_available(): - imgs = imgs.cuda() + imgs = imgs.cuda() + with torch.no_grad(): - preds = model(imgs) + preds,feau = model(imgs) + preds = preds.permute(1, 0, 2) + labels_class, labels_len = LabelConverter.encode(labels,preds.size(0)) - preds_size = Variable(torch.IntTensor([preds.size(0)] * config['testload']['batch_size'])) + preds_size = Variable(torch.IntTensor([preds.size(0)] * config['valload']['batch_size'])) pre_batch['preds'],pre_batch['preds_size'] = preds,preds_size gt_batch['labels'],gt_batch['labels_len'] = labels_class,labels_len + + if config['loss']['use_ctc_weight']: + len_index = torch.softmax(preds,-1).max(2)[1].transpose(0,1)>0 + len_flag = torch.cat([labels_len.cuda().long().unsqueeze(0),len_index.sum(1).unsqueeze(0)],0) + ctc_loss_weight = len_flag.max(0)[0].float()/len_flag.min(0)[0].float() + ctc_loss_weight[ctc_loss_weight==torch.tensor(np.inf).cuda()]=2.0 + gt_batch['ctc_loss_weight'] = ctc_loss_weight + cost = criterion(pre_batch, gt_batch) - loss_avg.append(cost.item()) + val_loss.update(cost.item()) + _, preds = preds.max(2) - preds = preds.squeeze(1) - preds = preds.transpose(1, 0).contiguous().view(-1) + preds = preds.squeeze(1).transpose(1, 0).contiguous().view(-1) + sim_preds = LabelConverter.decode(preds.data, preds_size.data, raw=False) for pred, target in zip(sim_preds, labels): if pred == target: - n_correct += 1 + n_correct.update(1) raw_preds = LabelConverter.decode(preds.data, preds_size.data, raw=True)[:config['base']['show_num']] for raw_pred, pred, gt in zip(raw_preds, sim_preds, labels): - print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt)) + 
train_log.info('recog example %-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt)) - val_acc = n_correct / float(len(test_data_loader) * config['testload']['batch_size']) - val_loss = np.mean(loss_avg) - print('Test loss: %f, accuray: %f' % (val_loss, val_acc)) - return val_acc,val_loss + val_acc = n_correct.sum / float(len(val_data_loader) * config['valload']['batch_size']) + train_log.info('val loss: %f, val accuray: %f' % (val_loss.avg, val_acc)) + return val_acc -def TrainValProgram(config): + +def TrainValProgram(args): config = yaml.load(open(args.config, 'r', encoding='utf-8'),Loader=yaml.FullLoader) config = merge_config(config,args) @@ -195,8 +130,7 @@ def TrainValProgram(config): os.environ["CUDA_VISIBLE_DEVICES"] = config['base']['gpu_id'] create_dir(config['base']['checkpoints']) - checkpoints = os.path.join(config['base']['checkpoints'], - "ag_%s_bb_%s_he_%s_bs_%d_ep_%d_%s" % (config['base']['algorithm'], + checkpoints = os.path.join(config['base']['checkpoints'],"ag_%s_bb_%s_he_%s_bs_%d_ep_%d_%s" % (config['base']['algorithm'], config['backbone']['function'].split(',')[-1], config['head']['function'].split(',')[-1], config['trainload']['batch_size'], @@ -208,40 +142,36 @@ def TrainValProgram(config): config['base']['classes'] = len(LabelConverter.alphabet) model = create_module(config['architectures']['model_function'])(config) criterion = create_module(config['architectures']['loss_function'])(config) - train_dataset = create_module(config['trainload']['function'])(config) - test_dataset = create_module(config['testload']['function'])(config) - optimizer = create_module(config['optimizer']['function'])(config, model) + train_dataset = create_module(config['trainload']['function'])(config,'train') + val_dataset = create_module(config['valload']['function'])(config,'val') + optimizer = create_module(config['optimizer']['function'])(config, model.parameters()) optimizer_decay = create_module(config['optimizer_decay']['function']) - - if config['loss']['use_center']: - center_criterion = create_module(config['loss']['center_function'])(config['base']['classes'],config['base']['classes']) - optimizer_center = torch.optim.Adam(center_criterion.parameters(), lr= config['loss']['center_lr']) - optimizer_decay_center = create_module(config['optimizer_decay_center']['function']) - else: - center_criterion = None - optimizer_center=None - - + if os.path.exists(os.path.join(checkpoints,'train_log.txt')): + os.remove(os.path.join(checkpoints,'train_log.txt')) + train_log = TrainLog(os.path.join(checkpoints,'train_log.txt')) + train_log.info(model) + train_data_loader = torch.utils.data.DataLoader( train_dataset, batch_size=config['trainload']['batch_size'], shuffle=True, num_workers=config['trainload']['num_workers'], - worker_init_fn = worker_init_fn, collate_fn = alignCollate(), drop_last=True, pin_memory=True) - test_data_loader = torch.utils.data.DataLoader( - test_dataset, - batch_size=config['testload']['batch_size'], + val_data_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size=config['valload']['batch_size'], shuffle=False, - num_workers=config['testload']['num_workers'], + num_workers=config['valload']['num_workers'], collate_fn = alignCollate(), drop_last=True, pin_memory=True) - - loss_bin = create_loss_bin(config['base']['algorithm'],use_center=config['loss']['use_center']) + + loss_dict = {} + for title in config['loss']['loss_title']: + loss_dict[title] = AverageMeter() if torch.cuda.is_available(): if (len(config['base']['gpu_id'].split(',')) > 1): @@ -254,51 
+184,19 @@ def TrainValProgram(config): val_acc = 0 val_loss = 0 best_acc = 0 - - + if config['base']['restore']: - print('Resuming from checkpoint.') + train_log.info('Resuming from checkpoint.') assert os.path.isfile(config['base']['restore_file']), 'Error: no checkpoint file found!' - checkpoint = torch.load(config['base']['restore_file']) - start_epoch = checkpoint['epoch'] - model.load_state_dict(checkpoint['state_dict']) - optimizer.load_state_dict(checkpoint['optimizer']) - best_acc = checkpoint['best_acc'] - if not config['loss']['use_center']: - log_write = Logger(os.path.join(checkpoints, 'log.txt'), title=config['base']['algorithm'], resume=True) - if config['loss']['use_center']: - if os.path.exists(os.path.join(checkpoints, 'log_center.txt')): - log_write = Logger(os.path.join(checkpoints, 'log_center.txt'), title=config['base']['algorithm'], resume=True) - else: - log_write = Logger(os.path.join(checkpoints, 'log_center.txt'), title=config['base']['algorithm']) - title = list(loss_bin.keys()) - title.extend(['val_loss','test_acc','best_acc']) - log_write.set_names(title) - else: - print('Training from scratch.') - log_write = Logger(os.path.join(checkpoints, 'log.txt'), title=config['base']['algorithm']) - title = list(loss_bin.keys()) - title.extend(['val_loss','test_acc','best_acc']) - log_write.set_names(title) - center_flag = False - center_model = None - if config['base']['finetune']: - start_epoch = 0 - optimizer.param_groups[0]['lr'] = 0.0001 - center_flag = True - center_model = copy.deepcopy(model) + model,optimizer,start_epoch,best_acc = restore_training(config['base']['restore_file'],model,optimizer) + for epoch in range(start_epoch,config['base']['n_epoch']): model.train() optimizer_decay(config, optimizer, epoch) - if config['loss']['use_center']: - optimizer_decay_center(config, optimizer_center, epoch) - loss_write,loss_flag = ModelTrain(train_data_loader,LabelConverter, model,center_model, criterion, optimizer, center_criterion,optimizer_center,center_flag,loss_bin, config, epoch) -# if loss_flag < config['loss']['min_score']: -# center_flag = True + ModelTrain(train_data_loader,LabelConverter, model,criterion,optimizer, train_log,loss_dict, config, epoch) if(epoch >= config['base']['start_val']): model.eval() - val_acc,val_loss = ModelEval(test_data_loader,LabelConverter, model,criterion ,config) - print('val_acc:',val_acc,'val_loss',val_loss) + val_acc = ModelEval(val_data_loader,LabelConverter, model,criterion,train_log,loss_dict,config) if (val_acc > best_acc): save_checkpoint({ 'epoch': epoch + 1, @@ -308,13 +206,11 @@ def TrainValProgram(config): 'best_acc': val_acc }, checkpoints, config['base']['algorithm'] + '_best' + '.pth.tar') best_acc = val_acc - - - loss_write.extend([val_loss,val_acc,best_acc]) - log_write.append(loss_write) - for key in loss_bin.keys(): - loss_bin[key].loss_clear() - if epoch%config['base']['save_epoch'] ==0: + train_log.info('best_acc:' + str(best_acc)) + for key in loss_dict.keys(): + loss_dict[key].reset() + + if epoch % config['base']['save_epoch'] == 0: save_checkpoint({ 'epoch': epoch + 1, 'state_dict': model.state_dict(), @@ -322,15 +218,12 @@ def TrainValProgram(config): 'optimizer': optimizer.state_dict(), 'best_acc': 0 },checkpoints,config['base']['algorithm']+'_'+str(epoch)+'.pth.tar') - + if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Hyperparams') parser.add_argument('--config', help='config file path') parser.add_argument('--log_str', help='log title') args = parser.parse_args() - - 
TrainValProgram(args) \ No newline at end of file diff --git a/tools/rec_train_bk1.py b/tools/rec_train_bk1.py new file mode 100644 index 0000000..811d27c --- /dev/null +++ b/tools/rec_train_bk1.py @@ -0,0 +1,396 @@ +# -*- coding:utf-8 _*- +""" +@author:fxw +@file: det_train.py +@time: 2020/08/07 +""" +import sys +sys.path.append('./') +import cv2 +import torch +import time +import os +import argparse +import random +import numpy as np +from tqdm import tqdm +from torch.autograd import Variable +np.seterr(divide='ignore', invalid='ignore') +import yaml +import torch.utils.data +from ptocr.utils.util_function import create_module, create_loss_bin, \ +set_seed,save_checkpoint,create_dir +from ptocr.utils.metrics import runningScore +from ptocr.utils.logger import Logger +from ptocr.utils.util_function import create_process_obj,merge_config,AverageMeter +from ptocr.dataloader.RecLoad.CRNNProcess import alignCollate +import copy +import re + +GLOBAL_WORKER_ID = None +GLOBAL_SEED = 2020 + + +torch.manual_seed(GLOBAL_SEED) +torch.cuda.manual_seed(GLOBAL_SEED) +torch.cuda.manual_seed_all(GLOBAL_SEED) +np.random.seed(GLOBAL_SEED) +random.seed(GLOBAL_SEED) + + + +def worker_init_fn(worker_id): + global GLOBAL_WORKER_ID + GLOBAL_WORKER_ID = worker_id + set_seed(GLOBAL_SEED + worker_id) + +# def backward_hook(self,grad_input, grad_output): +# for g in grad_input: +# g[g != g] = 0 # replace all nan/inf in gradients to zero + + +def ModelEval(test_data_loaders,LabelConverter,model,criterion,config): + loss_all= [] + acc_all = [] + for test_data_loader in test_data_loaders: + bar = tqdm(total=len(test_data_loader)) + loss_avg = [] + n_correct = 0 + for batch_idx, (imgs, labels) in enumerate(test_data_loader): + bar.update(1) + pre_batch = {} + gt_batch = {} + if torch.cuda.is_available(): + imgs = imgs.cuda() + with torch.no_grad(): + preds,feau = model(imgs) + preds = preds.permute(1, 0, 2) + + labels_class, labels_len = LabelConverter.encode(labels,preds.size(0)) + preds_size = Variable(torch.IntTensor([preds.size(0)] * config['valload']['batch_size'])) + pre_batch['preds'],pre_batch['preds_size'] = preds,preds_size + gt_batch['labels'],gt_batch['labels_len'] = labels_class,labels_len + + if config['loss']['use_ctc_weight']: + len_index = torch.softmax(preds,-1).max(2)[1].transpose(0,1)>0 + len_flag = torch.cat([labels_len.cuda().long().unsqueeze(0),len_index.sum(1).unsqueeze(0)],0) + ctc_loss_weight = len_flag.max(0)[0].float()/len_flag.min(0)[0].float() + ctc_loss_weight[ctc_loss_weight==torch.tensor(np.inf).cuda()]=2.0 + gt_batch['ctc_loss_weight'] = ctc_loss_weight + + cost = criterion(pre_batch, gt_batch) + loss_avg.append(cost.item()) + _, preds = preds.max(2) + preds = preds.squeeze(1) + preds = preds.contiguous().view(-1) + sim_preds = LabelConverter.decode(preds.data, preds_size.data, raw=False) + for pred, target in zip([sim_preds], labels): + target = ''.join(re.findall('[0-9a-zA-Z]+',target)).lower() + if pred == target: + n_correct += 1 + val_acc = n_correct / float(len(test_data_loader) * config['valload']['batch_size']) + val_loss = np.mean(loss_avg) + loss_all.append(val_loss) + acc_all.append(val_acc) + print('Test acc:' ,acc_all) + return loss_all,acc_all + +def ModelTrain(train_data_loader,test_data_loader,LabelConverter,model,center_model, criterion, optimizer,center_criterion,optimizer_center,center_flag,loss_bin, config,optimizer_decay,optimizer_decay_center,log_write,checkpoints): + batch_time = AverageMeter() + end = time.time() + best_acc = 0 + dataloader_bin = [] + fid = 
open('test.txt','w+') + for i in range(len(train_data_loader)): + dataloader_bin.append(enumerate(train_data_loader[i])) + + for iters in range(config['base']['start_iters'],config['base']['max_iters']): + model.train() + optimizer_decay(config, optimizer, iters) + if config['loss']['use_center']: + optimizer_decay_center(config, optimizer_center, iters) + + imgs = [] + labels = [] + try: + for i in range(len(train_data_loader)): + index,(img,label) = next(dataloader_bin[i]) + imgs.append(img) + labels.extend(label) + imgs = torch.cat(imgs,0) + except: + for i in range(len(train_data_loader)): + dataloader_bin[i] = enumerate(train_data_loader[i]) + continue + + pre_batch = {} + gt_batch = {} + + if torch.cuda.is_available(): + imgs = imgs.cuda() + preds,feau = model(imgs) + + preds = preds.permute(1, 0, 2) + + labels,labels_len = LabelConverter.encode(labels,preds.size(0)) + preds_size = Variable(torch.IntTensor([preds.size(0)] * config['trainload']['batch_size'])) + pre_batch['preds'],pre_batch['preds_size'] = preds,preds_size + gt_batch['labels'],gt_batch['labels_len'] = labels,labels_len + ######### + +# import pdb +# pdb.set_trace() + + if config['loss']['use_ctc_weight']: +# print('use') + + len_index = torch.softmax(preds,-1).max(2)[1].transpose(0,1)>0 + len_flag = torch.cat([labels_len.cuda().long().unsqueeze(0),len_index.sum(1).unsqueeze(0)],0) + ctc_loss_weight = len_flag.max(0)[0].float()/len_flag.min(0)[0].float() + + ctc_loss_weight[ctc_loss_weight==torch.tensor(np.inf).cuda()]=2.0 + gt_batch['ctc_loss_weight'] = ctc_loss_weight + + ctc_loss = criterion(pre_batch, gt_batch).cuda() + metrics = {} + metrics['loss_total'] = 0.0 + metrics['loss_center'] = 0.0 + if center_criterion is not None and center_flag is True: + center_model.eval() + ##### + feautures = preds.clone() + with torch.no_grad(): + center_preds,center_feau = center_model(imgs) + center_preds = center_preds.permute(1, 0, 2) + + center_preds = torch.softmax(center_preds,-1) + confs, center_preds = center_preds.max(2) + center_preds = center_preds.squeeze(1).transpose(1, 0).contiguous() + confs = confs.transpose(1, 0).contiguous() + +# confs = [] +# for i in range(center_preds.shape[0]): +# conf = [] +# for j in range(len(center_preds[i])): +# conf.append(probs[i,j,center_preds[i][j]]) +# confs.append(conf) +# confs = torch.Tensor(confs).cuda() + + b,t = center_preds.shape + +# feautures = feautures.transpose(1, 0).contiguous() + feautures = center_feau[0].transpose(1, 0).contiguous() + +# import pdb +# pdb.set_trace() + ### collapse repeated predictions (de-duplication) + repeat_index = (center_preds[:,:-1] == center_preds[:,1:]) + center_preds[:,:-1][repeat_index] = 0 + + confs = confs.view(-1) + center_preds = center_preds.view(-1) + feautures = feautures.view(b*t,-1) + + + index = (center_preds>0) & (confs>config['loss']['label_score']) + center_preds = center_preds[index] + feautures = feautures[index] + + center_loss = center_criterion(feautures,center_preds)*config['loss']['weight_center'] + + loss = ctc_loss + center_loss + + + metrics['loss_total'] = loss.item() + metrics['loss_ctc'] = ctc_loss.item() + metrics['loss_center'] = center_loss.item() + + ##### + optimizer_center.zero_grad() + optimizer.zero_grad() + + loss.backward() + optimizer.step() + + for param in center_criterion.parameters(): + param.grad.data *= (1. 
/ config['loss']['weight_center']) + optimizer_center.step() + else: + loss = ctc_loss + metrics['loss_ctc'] = ctc_loss.item() + optimizer.zero_grad() + loss.backward() + optimizer.step() + + for key in loss_bin.keys(): + loss_bin[key].loss_add(metrics[key]) + batch_time.update(time.time() - end) + end = time.time() + if (iters % config['base']['show_step'] == 0): + log = '({}/{}) | ' \ + .format(iters,config['base']['max_iters']) + bin_keys = list(loss_bin.keys()) + + for i in range(len(bin_keys)): + log += bin_keys[i] + ':{:.4f}'.format(loss_bin[bin_keys[i]].loss_mean()) + ' | ' + log += 'lr:{:.8f}'.format(optimizer.param_groups[0]['lr'])+ ' | ' + log+='batch_time:{:.2f} s'.format(batch_time.avg)+ ' | ' + log+='total_time:{:.2f} min'.format(batch_time.avg * iters / 60.0)+ ' | ' + log+='ETA:{:.2f} min'.format(batch_time.avg*(config['base']['max_iters']-iters)/60.0) + print(log) + + + if(iters % config['base']['eval_iter']==0 and iters!=0): + + loss_write = [] + for key in list(loss_bin.keys()): + loss_write.append(loss_bin[key].loss_mean()) + + model.eval() + val_loss,acc_all = ModelEval(test_data_loader,LabelConverter, model,criterion ,config) + val_acc = np.mean(acc_all) + if (val_acc > best_acc): + save_checkpoint({ + 'iters': iters + 1, + 'state_dict': model.state_dict(), + 'lr': config['optimizer']['base_lr'], + 'optimizer': optimizer.state_dict(), + 'best_acc': val_acc + }, checkpoints, config['base']['algorithm'] + '_best' + '.pth.tar') + best_acc = val_acc + acc_all.append(val_acc) + acc_all = [str(x) for x in acc_all] + fid.write(str(','.join(acc_all))+'\n') + fid.flush() + loss_write.extend([0,0,0]) + log_write.append(loss_write) + for key in loss_bin.keys(): + loss_bin[key].loss_clear() + if iters %config['base']['eval_iter'] ==0: + save_checkpoint({ + 'iters': iters + 1, + 'state_dict': model.state_dict(), + 'lr': config['optimizer']['base_lr'], + 'optimizer': optimizer.state_dict(), + 'best_acc': 0 + },checkpoints,config['base']['algorithm']+'_'+str(iters)+'.pth.tar') + + + + + + + + + +def TrainValProgram(config): + + config = yaml.load(open(args.config, 'r', encoding='utf-8'),Loader=yaml.FullLoader) + config = merge_config(config,args) + + os.environ["CUDA_VISIBLE_DEVICES"] = config['base']['gpu_id'] + + create_dir(config['base']['checkpoints']) + checkpoints = os.path.join(config['base']['checkpoints'], + "ag_%s_bb_%s_he_%s_bs_%d_ep_%d_%s" % (config['base']['algorithm'], + config['backbone']['function'].split(',')[-1], + config['head']['function'].split(',')[-1], + config['trainload']['batch_size'], + config['base']['max_iters'], + args.log_str)) + create_dir(checkpoints) + + LabelConverter = create_module(config['label_transform']['function'])(config) + config['base']['classes'] = len(LabelConverter.alphabet) + model = create_module(config['architectures']['model_function'])(config) + criterion = create_module(config['architectures']['loss_function'])(config) + train_data_loader = create_module(config['trainload']['function'])(config) + test_data_loader = create_module(config['valload']['function'])(config) + optimizer = create_module(config['optimizer']['function'])(config, model) + optimizer_decay = create_module(config['optimizer_decay']['function']) + + if config['loss']['use_center']: +# center_criterion = create_module(config['loss']['center_function'])(config['base']['classes'],config['base']['classes']) + center_criterion = create_module(config['loss']['center_function'])(config['base']['classes'],config['base']['hiddenchannel']) + + optimizer_center = 
torch.optim.Adam(center_criterion.parameters(), lr= config['loss']['center_lr']) + optimizer_decay_center = create_module(config['optimizer_decay_center']['function']) + else: + center_criterion = None + optimizer_center=None + optimizer_decay_center = None + + + + + loss_bin = create_loss_bin(config['base']['algorithm'],use_center=config['loss']['use_center']) + + if torch.cuda.is_available(): + if (len(config['base']['gpu_id'].split(',')) > 1): + model = torch.nn.DataParallel(model).cuda() + else: + model = model.cuda() + criterion = criterion.cuda() + + + + print(model) + + ### model.head.lstm_2.embedding = nn.Linear(in_features=512, out_features=2000, bias=True) + + if config['base']['restore']: + print('Resuming from checkpoint.') + assert os.path.isfile(config['base']['restore_file']), 'Error: no checkpoint file found!' + checkpoint = torch.load(config['base']['restore_file']) + start_epoch = checkpoint['epoch'] +# model.load_state_dict(checkpoint['state_dict']) + try: + model.load_state_dict(checkpoint['state_dict']) + except: + state = model.state_dict() + for key in state.keys(): + state[key] = checkpoint['state_dict'][key[7:]] + model.load_state_dict(state) + + optimizer.load_state_dict(checkpoint['optimizer']) + best_acc = checkpoint['best_acc'] + if not config['loss']['use_center']: + log_write = Logger(os.path.join(checkpoints, 'log.txt'), title=config['base']['algorithm'], resume=True) + if config['loss']['use_center']: + if os.path.exists(os.path.join(checkpoints, 'log_center.txt')): + log_write = Logger(os.path.join(checkpoints, 'log_center.txt'), title=config['base']['algorithm'], resume=True) + else: + log_write = Logger(os.path.join(checkpoints, 'log_center.txt'), title=config['base']['algorithm']) + title = list(loss_bin.keys()) + title.extend(['val_loss','test_acc','best_acc']) + log_write.set_names(title) + else: + print('Training from scratch.') + log_write = Logger(os.path.join(checkpoints, 'log.txt'), title=config['base']['algorithm']) + title = list(loss_bin.keys()) + title.extend(['val_loss','test_acc','best_acc']) + log_write.set_names(title) + center_flag = False + center_model = None + if config['base']['finetune']: + start_epoch = 0 + optimizer.param_groups[0]['lr'] = 0.0001 + center_flag = True + center_model = copy.deepcopy(model) + + + loss_write,loss_flag = ModelTrain(train_data_loader,test_data_loader,LabelConverter, model,center_model, criterion, optimizer, center_criterion,optimizer_center,center_flag,loss_bin, config,optimizer_decay,optimizer_decay_center,log_write,checkpoints) + + + + +if __name__ == "__main__": + + + parser = argparse.ArgumentParser(description='Hyperparams') + parser.add_argument('--config', help='config file path') + parser.add_argument('--log_str', help='log title') + args = parser.parse_args() + + + TrainValProgram(args) \ No newline at end of file diff --git a/tools/rec_train_bk2.py b/tools/rec_train_bk2.py new file mode 100644 index 0000000..ce6a9eb --- /dev/null +++ b/tools/rec_train_bk2.py @@ -0,0 +1,332 @@ +# -*- coding:utf-8 _*- +""" +@author:fxw +@file: det_train.py +@time: 2020/08/07 +""" +import sys +sys.path.append('./') +import cv2 +import torch +import time +import os +import argparse +import random +import numpy as np +from tqdm import tqdm +from torch.autograd import Variable +np.seterr(divide='ignore', invalid='ignore') +import yaml +import torch.utils.data +from ptocr.utils.util_function import create_module, create_loss_bin, \ +set_seed,save_checkpoint,create_dir +from ptocr.utils.metrics import runningScore 
+from ptocr.utils.logger import Logger +from ptocr.utils.util_function import create_process_obj,merge_config,AverageMeter +from ptocr.dataloader.RecLoad.CRNNProcess import alignCollate +import copy +import re + +GLOBAL_WORKER_ID = None +GLOBAL_SEED = 2020 + + +torch.manual_seed(GLOBAL_SEED) +torch.cuda.manual_seed(GLOBAL_SEED) +torch.cuda.manual_seed_all(GLOBAL_SEED) +np.random.seed(GLOBAL_SEED) +random.seed(GLOBAL_SEED) + + + +def worker_init_fn(worker_id): + global GLOBAL_WORKER_ID + GLOBAL_WORKER_ID = worker_id + set_seed(GLOBAL_SEED + worker_id) + +# def backward_hook(self,grad_input, grad_output): +# for g in grad_input: +# g[g != g] = 0 # replace all nan/inf in gradients to zero + + +def ModelEval(test_data_loaders,LabelConverter,model,criterion,config): + loss_all= [] + acc_all = [] + for test_data_loader in test_data_loaders: + bar = tqdm(total=len(test_data_loader)) + loss_avg = [] + n_correct = 0 + for batch_idx, (imgs, labels) in enumerate(test_data_loader): + bar.update(1) + pre_batch = {} + gt_batch = {} + if torch.cuda.is_available(): + imgs = imgs.cuda() + with torch.no_grad(): + preds,feau = model(imgs) + + + _,_, labels_class = LabelConverter.test_encode(labels) + + pre_batch['pred'],gt_batch['gt'] = preds,labels_class + + + if config['loss']['use_ctc_weight']: + len_index = torch.softmax(preds,-1).max(2)[1].transpose(0,1)>0 + len_flag = torch.cat([labels_len.cuda().long().unsqueeze(0),len_index.sum(1).unsqueeze(0)],0) + ctc_loss_weight = len_flag.max(0)[0].float()/len_flag.min(0)[0].float() + ctc_loss_weight[ctc_loss_weight==torch.tensor(np.inf).cuda()]=2.0 + gt_batch['ctc_loss_weight'] = ctc_loss_weight + + cost,_ = criterion(pre_batch, gt_batch) + loss_avg.append(cost.item()) + _, preds = preds.max(2) + + sim_preds = LabelConverter.decode(preds.data) + for pred, target in zip(sim_preds, labels): + target = ''.join(re.findall('[0-9a-zA-Z]+',target)).lower() + if pred == target: + n_correct += 1 + val_acc = n_correct / float(len(test_data_loader) * config['valload']['batch_size']) + val_loss = np.mean(loss_avg) + loss_all.append(val_loss) + acc_all.append(val_acc) + print('Test acc:' ,acc_all) + return loss_all,acc_all + +def ModelTrain(train_data_loader,test_data_loader,LabelConverter,model,center_model, criterion, optimizer,center_criterion,optimizer_center,center_flag,loss_bin, config,optimizer_decay,optimizer_decay_center,log_write,checkpoints): + batch_time = AverageMeter() + end = time.time() + best_acc = 0 + dataloader_bin = [] + fid = open('test.txt','w+') + for i in range(len(train_data_loader)): + dataloader_bin.append(enumerate(train_data_loader[i])) + + for iters in range(config['base']['start_iters'],config['base']['max_iters']): + model.train() + optimizer_decay(config, optimizer, iters) + if config['loss']['use_center']: + optimizer_decay_center(config, optimizer_center, iters) + + imgs = [] + labels = [] + try: + for i in range(len(train_data_loader)): + index,(img,label) = next(dataloader_bin[i]) + imgs.append(img) + labels.extend(label) + imgs = torch.cat(imgs,0) + except: + for i in range(len(train_data_loader)): + dataloader_bin[i] = enumerate(train_data_loader[i]) + continue + + pre_batch = {} + gt_batch = {} + + _, _, labels = LabelConverter.train_encode(labels) + + if torch.cuda.is_available(): + imgs = imgs.cuda() + preds,feau = model(imgs) + + + pre_batch['pred'],gt_batch['gt'] = preds,labels.long().cuda() + ######### + + + + if config['loss']['use_ctc_weight']: +# print('use') + + len_index = torch.softmax(preds,-1).max(2)[1].transpose(0,1)>0 + 
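Several of these scripts gate a per-sample CTC loss weight behind `config['loss']['use_ctc_weight']`: the weight is the ratio between the label length and the predicted non-blank length, clamped where the prediction is empty. Factored out of the inline code as a standalone sketch (tensor shapes assumed from the surrounding code):

```python
import torch

def ctc_loss_weight(preds, labels_len):
    """preds: (T, B, C) logits; labels_len: (B,) ground-truth label lengths."""
    # non-blank mask per time step: class 0 is the CTC blank
    non_blank = torch.softmax(preds, -1).max(2)[1].transpose(0, 1) > 0   # (B, T)
    pred_len = non_blank.sum(1)                                          # (B,)
    len_pair = torch.stack([labels_len.long(), pred_len], 0).float()     # (2, B)
    weight = len_pair.max(0)[0] / len_pair.min(0)[0]
    weight[torch.isinf(weight)] = 2.0   # empty prediction -> divide-by-zero guard
    return weight
```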
len_flag = torch.cat([labels_len.cuda().long().unsqueeze(0),len_index.sum(1).unsqueeze(0)],0) + ctc_loss_weight = len_flag.max(0)[0].float()/len_flag.min(0)[0].float() + + ctc_loss_weight[ctc_loss_weight==torch.tensor(np.inf).cuda()]=2.0 + gt_batch['ctc_loss_weight'] = ctc_loss_weight + + loss,_ = criterion(pre_batch, gt_batch) + +# import pdb +# pdb.set_trace() + + metrics = {} + metrics['loss_fc'] = loss.item() + optimizer.zero_grad() + loss.backward() + optimizer.step() + + for key in loss_bin.keys(): + loss_bin[key].loss_add(metrics[key]) + batch_time.update(time.time() - end) + end = time.time() + if (iters % config['base']['show_step'] == 0): + log = '({}/{}) | ' \ + .format(iters,config['base']['max_iters']) + bin_keys = list(loss_bin.keys()) + + for i in range(len(bin_keys)): + log += bin_keys[i] + ':{:.4f}'.format(loss_bin[bin_keys[i]].loss_mean()) + ' | ' + log += 'lr:{:.8f}'.format(optimizer.param_groups[0]['lr'])+ ' | ' + log+='batch_time:{:.2f} s'.format(batch_time.avg)+ ' | ' + log+='total_time:{:.2f} min'.format(batch_time.avg * iters / 60.0)+ ' | ' + log+='ETA:{:.2f} min'.format(batch_time.avg*(config['base']['max_iters']-iters)/60.0) + print(log) + + + if(iters % config['base']['eval_iter']==0 and iters!=0): + + loss_write = [] + for key in list(loss_bin.keys()): + loss_write.append(loss_bin[key].loss_mean()) + + model.eval() + val_loss,acc_all = ModelEval(test_data_loader,LabelConverter, model,criterion ,config) + val_acc = np.mean(acc_all) + if (val_acc > best_acc): + save_checkpoint({ + 'iters': iters + 1, + 'state_dict': model.state_dict(), + 'lr': config['optimizer']['base_lr'], + 'optimizer': optimizer.state_dict(), + 'best_acc': val_acc + }, checkpoints, config['base']['algorithm'] + '_best' + '.pth.tar') + best_acc = val_acc + acc_all.append(val_acc) + acc_all = [str(x) for x in acc_all] + fid.write(str(','.join(acc_all))+'\n') + fid.flush() + loss_write.extend([0,0,0]) + log_write.append(loss_write) + for key in loss_bin.keys(): + loss_bin[key].loss_clear() + if iters %config['base']['eval_iter'] ==0: + save_checkpoint({ + 'iters': iters + 1, + 'state_dict': model.state_dict(), + 'lr': config['optimizer']['base_lr'], + 'optimizer': optimizer.state_dict(), + 'best_acc': 0 + },checkpoints,config['base']['algorithm']+'_'+str(iters)+'.pth.tar') + + + + + + + + + +def TrainValProgram(config): + + config = yaml.load(open(args.config, 'r', encoding='utf-8'),Loader=yaml.FullLoader) + config = merge_config(config,args) + + os.environ["CUDA_VISIBLE_DEVICES"] = config['base']['gpu_id'] + + create_dir(config['base']['checkpoints']) + checkpoints = os.path.join(config['base']['checkpoints'], + "ag_%s_bb_%s_he_%s_bs_%d_ep_%d_%s" % (config['base']['algorithm'], + config['backbone']['function'].split(',')[-1], + config['head']['function'].split(',')[-1], + config['trainload']['batch_size'], + config['base']['max_iters'], + args.log_str)) + create_dir(checkpoints) + + LabelConverter = create_module(config['label_transform']['function'])(config) + + model = create_module(config['architectures']['model_function'])(config) + criterion = create_module(config['architectures']['loss_function'])(config) + train_data_loader = create_module(config['trainload']['function'])(config) + test_data_loader = create_module(config['valload']['function'])(config) + optimizer = create_module(config['optimizer']['function'])(config, model) + optimizer_decay = create_module(config['optimizer_decay']['function']) + + if config['loss']['use_center']: +# center_criterion = 
create_module(config['loss']['center_function'])(config['base']['classes'],config['base']['classes']) + center_criterion = create_module(config['loss']['center_function'])(config['base']['classes'],config['base']['hiddenchannel']) + + optimizer_center = torch.optim.Adam(center_criterion.parameters(), lr= config['loss']['center_lr']) + optimizer_decay_center = create_module(config['optimizer_decay_center']['function']) + else: + center_criterion = None + optimizer_center=None + optimizer_decay_center = None + + + + + loss_bin = create_loss_bin(config['base']['algorithm'],use_center=config['loss']['use_center']) + + if torch.cuda.is_available(): + if (len(config['base']['gpu_id'].split(',')) > 1): + model = torch.nn.DataParallel(model).cuda() + else: + model = model.cuda() + criterion = criterion.cuda() + + + + print(model) + + ### model.head.lstm_2.embedding = nn.Linear(in_features=512, out_features=2000, bias=True) + + if config['base']['restore']: + print('Resuming from checkpoint.') + assert os.path.isfile(config['base']['restore_file']), 'Error: no checkpoint file found!' + checkpoint = torch.load(config['base']['restore_file']) + start_epoch = checkpoint['epoch'] +# model.load_state_dict(checkpoint['state_dict']) + try: + model.load_state_dict(checkpoint['state_dict']) + except: + state = model.state_dict() + for key in state.keys(): + state[key] = checkpoint['state_dict'][key[7:]] + model.load_state_dict(state) + + optimizer.load_state_dict(checkpoint['optimizer']) + best_acc = checkpoint['best_acc'] + if not config['loss']['use_center']: + log_write = Logger(os.path.join(checkpoints, 'log.txt'), title=config['base']['algorithm'], resume=True) + if config['loss']['use_center']: + if os.path.exists(os.path.join(checkpoints, 'log_center.txt')): + log_write = Logger(os.path.join(checkpoints, 'log_center.txt'), title=config['base']['algorithm'], resume=True) + else: + log_write = Logger(os.path.join(checkpoints, 'log_center.txt'), title=config['base']['algorithm']) + title = list(loss_bin.keys()) + title.extend(['val_loss','test_acc','best_acc']) + log_write.set_names(title) + else: + print('Training from scratch.') + log_write = Logger(os.path.join(checkpoints, 'log.txt'), title=config['base']['algorithm']) + title = list(loss_bin.keys()) + title.extend(['val_loss','test_acc','best_acc']) + log_write.set_names(title) + center_flag = False + center_model = None + if config['base']['finetune']: + start_epoch = 0 + optimizer.param_groups[0]['lr'] = 0.0001 + center_flag = True + center_model = copy.deepcopy(model) + + + loss_write,loss_flag = ModelTrain(train_data_loader,test_data_loader,LabelConverter, model,center_model, criterion, optimizer, center_criterion,optimizer_center,center_flag,loss_bin, config,optimizer_decay,optimizer_decay_center,log_write,checkpoints) + + + + +if __name__ == "__main__": + + + parser = argparse.ArgumentParser(description='Hyperparams') + parser.add_argument('--config', help='config file path') + parser.add_argument('--log_str', help='log title') + args = parser.parse_args() + + + TrainValProgram(args) \ No newline at end of file diff --git a/tools/rec_train_bk3.py b/tools/rec_train_bk3.py new file mode 100644 index 0000000..67d6c11 --- /dev/null +++ b/tools/rec_train_bk3.py @@ -0,0 +1,376 @@ +# -*- coding:utf-8 _*- +""" +@author:fxw +@file: det_train.py +@time: 2020/08/07 +""" +import sys +sys.path.append('./') +import cv2 +import torch +import time +import os +import argparse +import random +import numpy as np +from tqdm import tqdm +from torch.autograd 
import Variable +np.seterr(divide='ignore', invalid='ignore') +import yaml +import torch.utils.data +from ptocr.utils.util_function import create_module, create_loss_bin, \ +set_seed,save_checkpoint,create_dir +from ptocr.utils.metrics import runningScore +from ptocr.utils.logger import Logger +from ptocr.utils.util_function import create_process_obj,merge_config,AverageMeter +from ptocr.dataloader.RecLoad.CRNNProcess import alignCollate +import copy + +GLOBAL_WORKER_ID = None +GLOBAL_SEED = 2020 + + +torch.manual_seed(GLOBAL_SEED) +torch.cuda.manual_seed(GLOBAL_SEED) +torch.cuda.manual_seed_all(GLOBAL_SEED) +np.random.seed(GLOBAL_SEED) +random.seed(GLOBAL_SEED) + + + +def worker_init_fn(worker_id): + global GLOBAL_WORKER_ID + GLOBAL_WORKER_ID = worker_id + set_seed(GLOBAL_SEED + worker_id) + +# def backward_hook(self,grad_input, grad_output): +# for g in grad_input: +# g[g != g] = 0 # replace all nan/inf in gradients to zero + + +def ModelTrain(train_data_loader,LabelConverter,model,center_model, criterion, optimizer,center_criterion,optimizer_center,center_flag,loss_bin, config, epoch): + batch_time = AverageMeter() + end = time.time() + for batch_idx, data in enumerate(train_data_loader): +# model.register_backward_hook(backward_hook) + if(data is None): + continue + imgs,labels = data + pre_batch = {} + gt_batch = {} + + if torch.cuda.is_available(): + imgs = imgs.cuda() + preds,feau = model(imgs) + + preds = preds.permute(1, 0, 2) + + labels,labels_len = LabelConverter.encode(labels,preds.size(0)) + preds_size = Variable(torch.IntTensor([preds.size(0)] * config['trainload']['batch_size'])) + pre_batch['preds'],pre_batch['preds_size'] = preds,preds_size + gt_batch['labels'],gt_batch['labels_len'] = labels,labels_len + ######### + if config['loss']['use_ctc_weight']: + + len_index = torch.softmax(preds,-1).max(2)[1].transpose(0,1)>0 + len_flag = torch.cat([labels_len.cuda().long().unsqueeze(0),len_index.sum(1).unsqueeze(0)],0) + ctc_loss_weight = len_flag.max(0)[0].float()/len_flag.min(0)[0].float() + + ctc_loss_weight[ctc_loss_weight==torch.tensor(np.inf).cuda()]=2.0 + gt_batch['ctc_loss_weight'] = ctc_loss_weight + + ctc_loss = criterion(pre_batch, gt_batch).cuda() + metrics = {} + metrics['loss_total'] = 0.0 + metrics['loss_center'] = 0.0 + if center_criterion is not None and center_flag is True: + center_model.eval() + ##### + feautures = preds.clone() + with torch.no_grad(): + center_preds,center_feau = center_model(imgs) + center_preds = center_preds.permute(1, 0, 2) + + center_preds = torch.softmax(center_preds,-1) + confs, center_preds = center_preds.max(2) + center_preds = center_preds.squeeze(1).transpose(1, 0).contiguous() + confs = confs.transpose(1, 0).contiguous() + +# confs = [] +# for i in range(center_preds.shape[0]): +# conf = [] +# for j in range(len(center_preds[i])): +# conf.append(probs[i,j,center_preds[i][j]]) +# confs.append(conf) +# confs = torch.Tensor(confs).cuda() + + b,t = center_preds.shape + +# feautures = feautures.transpose(1, 0).contiguous() + feautures = center_feau[0].transpose(1, 0).contiguous() + +# import pdb +# pdb.set_trace() + ### collapse repeated predictions (de-duplication) + repeat_index = (center_preds[:,:-1] == center_preds[:,1:]) + center_preds[:,:-1][repeat_index] = 0 + + confs = confs.view(-1) + center_preds = center_preds.view(-1) + feautures = feautures.view(b*t,-1) + + + index = (center_preds>0) & (confs>config['loss']['label_score']) + center_preds = center_preds[index] + feautures = feautures[index] + + center_loss = 
center_criterion(feautures,center_preds)*config['loss']['weight_center'] + + loss = ctc_loss + center_loss + + + metrics['loss_total'] = loss.item() + metrics['loss_ctc'] = ctc_loss.item() + metrics['loss_center'] = center_loss.item() + + ##### + optimizer_center.zero_grad() + optimizer.zero_grad() + + loss.backward() + optimizer.step() + + for param in center_criterion.parameters(): + param.grad.data *= (1. / config['loss']['weight_center']) + optimizer_center.step() + else: + loss = ctc_loss + metrics['loss_ctc'] = ctc_loss.item() + optimizer.zero_grad() + loss.backward() + optimizer.step() + + for key in loss_bin.keys(): + loss_bin[key].loss_add(metrics[key]) + batch_time.update(time.time() - end) + end = time.time() + if (batch_idx % config['base']['show_step'] == 0): + log = '({}/{}/{}/{}) | ' \ + .format(epoch, config['base']['n_epoch'], batch_idx, len(train_data_loader)) + bin_keys = list(loss_bin.keys()) + + for i in range(len(bin_keys)): + log += bin_keys[i] + ':{:.4f}'.format(loss_bin[bin_keys[i]].loss_mean()) + ' | ' + log += 'lr:{:.8f}'.format(optimizer.param_groups[0]['lr'])+ ' | ' + log+='batch_time:{:.2f} s'.format(batch_time.avg)+ ' | ' + log+='total_time:{:.2f} min'.format(batch_time.avg * batch_idx / 60.0)+ ' | ' + log+='ETA:{:.2f} min'.format(batch_time.avg*(len(train_data_loader)-batch_idx)/60.0) + print(log) + loss_write = [] + for key in list(loss_bin.keys()): + loss_write.append(loss_bin[key].loss_mean()) + return loss_write,loss_bin['loss_ctc'].loss_mean() + + +def ModelEval(test_data_loader,LabelConverter,model,criterion,config): + bar = tqdm(total=len(test_data_loader)) + loss_avg = [] + n_correct = 0 + for batch_idx, (imgs, labels) in enumerate(test_data_loader): + bar.update(1) + pre_batch = {} + gt_batch = {} + if torch.cuda.is_available(): + imgs = imgs.cuda() + with torch.no_grad(): + preds,feau = model(imgs) + preds = preds.permute(1, 0, 2) + + labels_class, labels_len = LabelConverter.encode(labels,preds.size(0)) + preds_size = Variable(torch.IntTensor([preds.size(0)] * config['testload']['batch_size'])) + pre_batch['preds'],pre_batch['preds_size'] = preds,preds_size + gt_batch['labels'],gt_batch['labels_len'] = labels_class,labels_len + + if config['loss']['use_ctc_weight']: + len_index = torch.softmax(preds,-1).max(2)[1].transpose(0,1)>0 + len_flag = torch.cat([labels_len.cuda().long().unsqueeze(0),len_index.sum(1).unsqueeze(0)],0) + ctc_loss_weight = len_flag.max(0)[0].float()/len_flag.min(0)[0].float() + ctc_loss_weight[ctc_loss_weight==torch.tensor(np.inf).cuda()]=2.0 + gt_batch['ctc_loss_weight'] = ctc_loss_weight + + cost = criterion(pre_batch, gt_batch) + loss_avg.append(cost.item()) + _, preds = preds.max(2) + preds = preds.squeeze(1) + preds = preds.transpose(1, 0).contiguous().view(-1) + sim_preds = LabelConverter.decode(preds.data, preds_size.data, raw=False) + for pred, target in zip(sim_preds, labels): + if pred == target: + n_correct += 1 + raw_preds = LabelConverter.decode(preds.data, preds_size.data, raw=True)[:config['base']['show_num']] + for raw_pred, pred, gt in zip(raw_preds, sim_preds, labels): + print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt)) + + val_acc = n_correct / float(len(test_data_loader) * config['testload']['batch_size']) + val_loss = np.mean(loss_avg) + print('Test loss: %f, accuracy: %f' % (val_loss, val_acc)) + return val_acc,val_loss + +def TrainValProgram(config): + + config = yaml.load(open(args.config, 'r', encoding='utf-8'),Loader=yaml.FullLoader) + config = merge_config(config,args) + 
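The center-loss branch above (shared by rec_train_bk1.py and rec_train_bk3.py) selects which time steps feed the center loss: consecutive duplicate predictions are collapsed, blanks are dropped, and only steps whose confidence clears `config['loss']['label_score']` are kept. The same logic as a standalone sketch, with shapes assumed from the surrounding code:

```python
import torch

def select_center_candidates(center_preds, confs, features, label_score):
    """center_preds, confs: (B, T); features: (B, T, C). Returns filtered (features, labels)."""
    b, t, c = features.shape
    # zero out repeats so each symbol contributes at most one time step
    repeat_index = center_preds[:, :-1] == center_preds[:, 1:]
    center_preds[:, :-1][repeat_index] = 0
    center_preds = center_preds.reshape(-1)
    confs = confs.reshape(-1)
    features = features.reshape(b * t, c)
    keep = (center_preds > 0) & (confs > label_score)   # non-blank and confident
    return features[keep], center_preds[keep]
```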
os.environ["CUDA_VISIBLE_DEVICES"] = config['base']['gpu_id'] + + create_dir(config['base']['checkpoints']) + checkpoints = os.path.join(config['base']['checkpoints'], + "ag_%s_bb_%s_he_%s_bs_%d_ep_%d_%s" % (config['base']['algorithm'], + config['backbone']['function'].split(',')[-1], + config['head']['function'].split(',')[-1], + config['trainload']['batch_size'], + config['base']['n_epoch'], + args.log_str)) + create_dir(checkpoints) + + LabelConverter = create_module(config['label_transform']['function'])(config) + config['base']['classes'] = len(LabelConverter.alphabet) + model = create_module(config['architectures']['model_function'])(config) + criterion = create_module(config['architectures']['loss_function'])(config) + train_dataset = create_module(config['trainload']['function'])(config) + test_dataset = create_module(config['testload']['function'])(config) + optimizer = create_module(config['optimizer']['function'])(config, model) + optimizer_decay = create_module(config['optimizer_decay']['function']) + + if config['loss']['use_center']: +# center_criterion = create_module(config['loss']['center_function'])(config['base']['classes'],config['base']['classes']) + center_criterion = create_module(config['loss']['center_function'])(config['base']['classes'],config['base']['hiddenchannel']) + + optimizer_center = torch.optim.Adam(center_criterion.parameters(), lr= config['loss']['center_lr']) + optimizer_decay_center = create_module(config['optimizer_decay_center']['function']) + else: + center_criterion = None + optimizer_center=None + + + train_data_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=config['trainload']['batch_size'], + shuffle=True, + num_workers=config['trainload']['num_workers'], + worker_init_fn = worker_init_fn, + collate_fn = alignCollate(), + drop_last=True, + pin_memory=True) + + test_data_loader = torch.utils.data.DataLoader( + test_dataset, + batch_size=config['testload']['batch_size'], + shuffle=False, + num_workers=config['testload']['num_workers'], + collate_fn = alignCollate(), + drop_last=True, + pin_memory=True) + + loss_bin = create_loss_bin(config['base']['algorithm'],use_center=config['loss']['use_center']) + + if torch.cuda.is_available(): + if (len(config['base']['gpu_id'].split(',')) > 1): + model = torch.nn.DataParallel(model).cuda() + else: + model = model.cuda() + criterion = criterion.cuda() + + start_epoch = 0 + val_acc = 0 + val_loss = 0 + best_acc = 0 + + print(model) + + if config['base']['restore']: + print('Resuming from checkpoint.') + assert os.path.isfile(config['base']['restore_file']), 'Error: no checkpoint file found!' 
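The restore branch that follows falls back to remapping checkpoint keys when the direct `load_state_dict` fails, which handles a `DataParallel`-wrapped model loading a checkpoint that was saved without the `module.` prefix. The same fallback as a standalone sketch:

```python
import torch

def load_state_flexible(model, state_dict):
    """Load a checkpoint, remapping keys if the model is DataParallel-wrapped."""
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        remapped = {}
        for key in model.state_dict().keys():
            # model keys look like 'module.xxx'; the checkpoint stores plain 'xxx'
            remapped[key] = state_dict[key[len('module.'):]]
        model.load_state_dict(remapped)
```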
+ checkpoint = torch.load(config['base']['restore_file']) + start_epoch = checkpoint['epoch'] +# model.load_state_dict(checkpoint['state_dict']) + try: + model.load_state_dict(checkpoint['state_dict']) + except: + state = model.state_dict() + for key in state.keys(): + state[key] = checkpoint['state_dict'][key[7:]] + model.load_state_dict(state) + + optimizer.load_state_dict(checkpoint['optimizer']) + best_acc = checkpoint['best_acc'] + if not config['loss']['use_center']: + log_write = Logger(os.path.join(checkpoints, 'log.txt'), title=config['base']['algorithm'], resume=True) + if config['loss']['use_center']: + if os.path.exists(os.path.join(checkpoints, 'log_center.txt')): + log_write = Logger(os.path.join(checkpoints, 'log_center.txt'), title=config['base']['algorithm'], resume=True) + else: + log_write = Logger(os.path.join(checkpoints, 'log_center.txt'), title=config['base']['algorithm']) + title = list(loss_bin.keys()) + title.extend(['val_loss','test_acc','best_acc']) + log_write.set_names(title) + else: + print('Training from scratch.') + log_write = Logger(os.path.join(checkpoints, 'log.txt'), title=config['base']['algorithm']) + title = list(loss_bin.keys()) + title.extend(['val_loss','test_acc','best_acc']) + log_write.set_names(title) + center_flag = False + center_model = None + if config['base']['finetune']: + start_epoch = 0 + optimizer.param_groups[0]['lr'] = 0.0001 + center_flag = True + center_model = copy.deepcopy(model) + for epoch in range(start_epoch,config['base']['n_epoch']): + model.train() + optimizer_decay(config, optimizer, epoch) + if config['loss']['use_center']: + optimizer_decay_center(config, optimizer_center, epoch) + loss_write,loss_flag = ModelTrain(train_data_loader,LabelConverter, model,center_model, criterion, optimizer, center_criterion,optimizer_center,center_flag,loss_bin, config, epoch) +# if loss_flag < config['loss']['min_score']: +# center_flag = True + if(epoch >= config['base']['start_val']): + model.eval() + val_acc,val_loss = ModelEval(test_data_loader,LabelConverter, model,criterion ,config) + print('val_acc:',val_acc,'val_loss',val_loss) + if (val_acc > best_acc): + save_checkpoint({ + 'epoch': epoch + 1, + 'state_dict': model.state_dict(), + 'lr': config['optimizer']['base_lr'], + 'optimizer': optimizer.state_dict(), + 'best_acc': val_acc + }, checkpoints, config['base']['algorithm'] + '_best' + '.pth.tar') + best_acc = val_acc + + + loss_write.extend([val_loss,val_acc,best_acc]) + log_write.append(loss_write) + for key in loss_bin.keys(): + loss_bin[key].loss_clear() + if epoch%config['base']['save_epoch'] ==0: + save_checkpoint({ + 'epoch': epoch + 1, + 'state_dict': model.state_dict(), + 'lr': config['optimizer']['base_lr'], + 'optimizer': optimizer.state_dict(), + 'best_acc': 0 + },checkpoints,config['base']['algorithm']+'_'+str(epoch)+'.pth.tar') + + +if __name__ == "__main__": + + + parser = argparse.ArgumentParser(description='Hyperparams') + parser.add_argument('--config', help='config file path') + parser.add_argument('--log_str', help='log title') + args = parser.parse_args() + + + TrainValProgram(args) \ No newline at end of file
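For reference, `LabelConverter.decode` as used in these evaluation loops (argmax over classes, then `raw=False`) amounts to greedy CTC decoding. A sketch of that convention, assuming class 0 is the blank and class `i` maps to `alphabet[i-1]` (the actual converter lives in the repo's label_transform module and may differ):

```python
def ctc_greedy_decode(indices, alphabet):
    """indices: per-timestep argmax class ids for one sample (class 0 = CTC blank)."""
    out, prev = [], 0
    for idx in indices:
        if idx != 0 and idx != prev:   # drop blanks and collapse repeats
            out.append(alphabet[idx - 1])
        prev = idx
    return ''.join(out)

# e.g. ctc_greedy_decode([1, 1, 0, 2, 2, 0, 1], 'ab') == 'aba'
```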