diff --git a/README.md b/README.md
index eac6f3b..21f6bc1 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,10 @@
***
Recent updates:
+- 2021.05.19 Added multilingual text detection based on DBNet.
+- 2021.05.01 Updated CRNN training: fixed multi-GPU training and switched to LMDB-based training (images must first be converted to LMDB; the script folder contains multiprocess conversion code, and a sketch follows this list). Added several training optimizations and changed the model structure (use the yaml files with "lmdb" in the name for training). Actual training results are in the table below.
+- 2021.03.26 Updated CRNN training results; the code was cleaned up and uploaded.
+- 2021.03.06 Added CRNN backbones (resnet and mobilev3) and their config files.
- 2020.12.22 Added CRNN + CTCLoss + CenterLoss training
- 2020.09.18 Updated the text detection documentation
- 2020.09.12 Added training/testing code and pretrained models for DB, PSE, PAN, SAST, and CRNN
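+
+A minimal single-process sketch of the image-to-LMDB conversion (a hypothetical helper; the multiprocess version ships in the script folder). Key names follow the `num-samples` / `image-%09d` / `label-%09d` convention that the LMDB loaders read back:
+
+```python
+import lmdb
+
+def create_lmdb(output_path, samples):
+    """samples: list of (image_path, label) pairs, e.g. parsed from train_list.txt."""
+    env = lmdb.open(output_path, map_size=1 << 40)  # sparse memory map, up to ~1TB
+    with env.begin(write=True) as txn:
+        for i, (img_path, label) in enumerate(samples, start=1):
+            with open(img_path, 'rb') as f:
+                txn.put(('image-%09d' % i).encode(), f.read())  # raw encoded image bytes
+            txn.put(('label-%09d' % i).encode(), label.encode())
+        txn.put('num-samples'.encode(), str(len(samples)).encode())
+```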
@@ -27,6 +31,17 @@
- [ ] Train a general-purpose OCR model
- [ ] Deployment with chinese_lite
- [ ] Deployment on mobile
+***
+### CRNN model results (experimental)
+Trained on MJSynth (MJ) and SynthText (ST) with batch size 512, then evaluated on the following datasets:
+
| Model | Iterations | CUTE80 | IC03_867 |IC13_1015|IC13_857|IC15_1811|IC15_2077|IIIT5k_3000|SVT|SVTP|mean|
+|-|-|-|-|-|-|-|-|-|-|-|-|
+| resnet34+lstm+ctc |120000| 82.98| 91.92|90.93|91.59|73.10|67.98|90.16|85.16|78.29|83.56|
+| mobilev3_large+lstm+ctc | 210000| 73.61| 92.50|90.34|91.59|74.82|68.89|87.56|83.46|77.20|82.21|
+| mobilev3_small+lstm+ctc | 210000| 66.31| 90.77|88.76|91.13|73.66|69.52|88.80|84.54|72.24|80.64|
+
+
***
### Detection model results (experimental)
@@ -93,6 +108,18 @@
+***
+
+### DBNet multilingual text detection results
+
+#### Synthetic dataset:
+
+
+#### Public datasets:
+
+
+
+
***
### Questions and discussion: add me on WeChat
diff --git a/bg_img/1.jpg b/bg_img/1.jpg
new file mode 100644
index 0000000..cab5298
Binary files /dev/null and b/bg_img/1.jpg differ
diff --git a/bg_img/2.jpg b/bg_img/2.jpg
new file mode 100644
index 0000000..7d077d5
Binary files /dev/null and b/bg_img/2.jpg differ
diff --git a/bg_img/3.jpg b/bg_img/3.jpg
new file mode 100644
index 0000000..da162d0
Binary files /dev/null and b/bg_img/3.jpg differ
diff --git a/bg_img/4.jpg b/bg_img/4.jpg
new file mode 100644
index 0000000..98edb26
Binary files /dev/null and b/bg_img/4.jpg differ
diff --git a/bg_img/5.jpg b/bg_img/5.jpg
new file mode 100644
index 0000000..a9e3616
Binary files /dev/null and b/bg_img/5.jpg differ
diff --git a/bg_img/6.jpg b/bg_img/6.jpg
new file mode 100644
index 0000000..3601bd5
Binary files /dev/null and b/bg_img/6.jpg differ
diff --git a/bg_img/7.jpg b/bg_img/7.jpg
new file mode 100644
index 0000000..fc01ae2
Binary files /dev/null and b/bg_img/7.jpg differ
diff --git a/bg_img/8.jpg b/bg_img/8.jpg
new file mode 100644
index 0000000..a94db1d
Binary files /dev/null and b/bg_img/8.jpg differ
diff --git a/bg_img/9.jpg b/bg_img/9.jpg
new file mode 100644
index 0000000..103e601
Binary files /dev/null and b/bg_img/9.jpg differ
diff --git "a/checkpoint/\346\226\260\345\273\272\346\226\207\346\234\254\346\226\207\346\241\243.txt" "b/checkpoint/\346\226\260\345\273\272\346\226\207\346\234\254\346\226\207\346\241\243.txt"
deleted file mode 100644
index e69de29..0000000
diff --git a/config/det_DB_mobilev3.yaml b/config/det_DB_mobilev3.yaml
index b286185..bee1d35 100644
--- a/config/det_DB_mobilev3.yaml
+++ b/config/det_DB_mobilev3.yaml
@@ -9,12 +9,12 @@ base:
crop_shape: [640,640]
shrink_ratio: 0.4
n_epoch: 1200
- start_val: 400
+ start_val: 700
show_step: 20
checkpoints: ./checkpoint
save_epoch: 100
- restore: True
- restore_file : ./checkpoint/ag_DB_bb_mobilenet_v3_small_he_DB_Head_bs_16_ep_1200_mobile_slim_all/DB_best.pth.tar
+ restore: False
+ restore_file : ./checkpoint/DB_best.pth.tar
backbone:
function: ptocr.model.backbone.det_mobilev3,mobilenet_v3_small
@@ -82,6 +82,6 @@ postprocess:
min_size: 3
infer:
- model_path: './checkpoint/ag_DB_bb_mobilenet_v3_small_he_DB_Head_bs_16_ep_1200/DB_best.pth.tar'
+ model_path: './checkpoint/DB_best.pth.tar'
path: '/src/notebooks/detect_text/icdar2015/ch4_test_images'
save_path: './result'
diff --git a/config/det_DB_mobilev3_common.yaml b/config/det_DB_mobilev3_common.yaml
deleted file mode 100644
index 2de20ac..0000000
--- a/config/det_DB_mobilev3_common.yaml
+++ /dev/null
@@ -1,87 +0,0 @@
-base:
- gpu_id: '2'
- algorithm: DB
- pretrained: True
- in_channels: [24, 40, 48, 96]
- inner_channels: 96
- k: 50
- adaptive: True
- crop_shape: [640,640]
- shrink_ratio: 0.4
- n_epoch: 400
- start_val: 500
- show_step: 20
- checkpoints: ./checkpoint
- save_epoch: 1
- restore: False
- restore_file : ./checkpoint/ag_DB_bb_mobilenet_v3_small_he_DB_Head_bs_16_ep_400/DB_35.pth.tar
-
-backbone:
- function: ptocr.model.backbone.det_mobilev3,mobilenet_v3_small
-
-head:
- function: ptocr.model.head.det_DBHead,DB_Head
-# function: ptocr.model.head.det_FPEM_FFM_Head,FPEM_FFM_Head
-# function: ptocr.model.head.det_FPNHead,FPN_Head
-
-segout:
- function: ptocr.model.segout.det_DB_segout,SegDetector
-
-architectures:
- model_function: ptocr.model.architectures.det_model,DetModel
- loss_function: ptocr.model.architectures.det_model,DetLoss
-
-loss:
- function: ptocr.model.loss.db_loss,DBLoss
- l1_scale: 10
- bce_scale: 1
-
-#optimizer:
-# function: ptocr.optimizer,AdamDecay
-# base_lr: 0.002
-# beta1: 0.9
-# beta2: 0.999
-
-optimizer:
- function: ptocr.optimizer,SGDDecay
- base_lr: 0.002
- momentum: 0.99
- weight_decay: 0.00005
-
-optimizer_decay:
- function: ptocr.optimizer,adjust_learning_rate_poly
- factor: 0.9
-
-#optimizer_decay:
-# function: ptocr.optimizer,adjust_learning_rate
-# schedule: [1,2]
-# gama: 0.1
-
-trainload:
- function: ptocr.dataloader.DetLoad.DBProcess,DBProcessTrain
- train_file: /src/notebooks/chinese_recognize_data/detection/CommonData/train_list.txt
- num_workers: 10
- batch_size: 16
-
-testload:
- function: ptocr.dataloader.DetLoad.DBProcess,DBProcessTest
- test_file: /src/notebooks/detect_text/icdar2015/test_list.txt
- test_gt_path: /src/notebooks/detect_text/icdar2015/ch4_test_gts/
- test_size: 736
- stride: 32
- num_workers: 5
- batch_size: 4
-
-postprocess:
- function: ptocr.postprocess.DBpostprocess,DBPostProcess
- is_poly: False
- thresh: 0.2
- box_thresh: 0.4
- max_candidates: 1000
- unclip_ratio: 2
- min_size: 3
-
-infer:
- model_path: './checkpoint/ag_DB_bb_mobilenet_v3_small_he_DB_Head_bs_16_ep_1200/DB_best.pth.tar'
- path: '/src/notebooks/detect_text/icdar2015/ch4_test_images'
- save_path: './result'
diff --git a/config/det_DB_mobilev3_pytorch_qua.yaml b/config/det_DB_mobilev3_pytorch_qua.yaml
deleted file mode 100644
index 9206e8f..0000000
--- a/config/det_DB_mobilev3_pytorch_qua.yaml
+++ /dev/null
@@ -1,88 +0,0 @@
-base:
- gpu_id: '2'
- algorithm: DB
- backend: 'qnnpack' # fbgemm
- pretrained: False
- in_channels: [24, 40, 48, 96]
- inner_channels: 96
- k: 50
- adaptive: True
- crop_shape: [640,640]
- shrink_ratio: 0.4
- n_epoch: 1200
- start_val: 400
- show_step: 20
- checkpoints: ./checkpoint
- save_epoch: 100
- restore: False
- restore_file : ./checkpoint/ag_DB_bb_mobilenet_v3_small_he_DB_Head_bs_16_ep_1200_mobile_slim_all/DB_best.pth.tar
-
-backbone:
- function: ptocr.model.backbone.det_mobilev3_pytorch_qua,mobilenet_v3_small
-
-head:
- function: ptocr.model.head.det_DBHead_Qua,DB_Head
-# function: ptocr.model.head.det_FPEM_FFM_Head,FPEM_FFM_Head
-# function: ptocr.model.head.det_FPNHead,FPN_Head
-
-segout:
- function: ptocr.model.segout.det_DB_segout_qua,SegDetector
-
-architectures:
- model_function: ptocr.model.architectures.det_model_q,DetModel
- loss_function: ptocr.model.architectures.det_model_q,DetLoss
-
-loss:
- function: ptocr.model.loss.db_loss,DBLoss
- l1_scale: 10
- bce_scale: 1
-
-#optimizer:
-# function: ptocr.optimizer,AdamDecay
-# base_lr: 0.002
-# beta1: 0.9
-# beta2: 0.999
-
-optimizer:
- function: ptocr.optimizer,SGDDecay
- base_lr: 0.002
- momentum: 0.99
- weight_decay: 0.00005
-
-optimizer_decay:
- function: ptocr.optimizer,adjust_learning_rate_poly
- factor: 0.9
-
-#optimizer_decay:
-# function: ptocr.optimizer,adjust_learning_rate
-# schedule: [1,2]
-# gama: 0.1
-
-trainload:
- function: ptocr.dataloader.DetLoad.DBProcess,DBProcessTrain
- train_file: /src/notebooks/detect_text/icdar2015/train_list.txt
- num_workers: 10
- batch_size: 16
-
-testload:
- function: ptocr.dataloader.DetLoad.DBProcess,DBProcessTest
- test_file: /src/notebooks/detect_text/icdar2015/test_list.txt
- test_gt_path: /src/notebooks/detect_text/icdar2015/ch4_test_gts/
- test_size: 736
- stride: 32
- num_workers: 5
- batch_size: 4
-
-postprocess:
- function: ptocr.postprocess.DBpostprocess,DBPostProcess
- is_poly: False
- thresh: 0.5
- box_thresh: 0.6
- max_candidates: 1000
- unclip_ratio: 2
- min_size: 3
-
-infer:
- model_path: './checkpoint/ag_DB_bb_mobilenet_v3_small_he_DB_Head_bs_16_ep_1200/DB_best.pth.tar'
- path: '/src/notebooks/detect_text/icdar2015/ch4_test_images'
- save_path: './result'
diff --git a/config/det_DB_resnet50_3_3.yaml b/config/det_DB_resnet50_3_3.yaml
index 8b117b5..ff52b4a 100644
--- a/config/det_DB_resnet50_3_3.yaml
+++ b/config/det_DB_resnet50_3_3.yaml
@@ -1,15 +1,15 @@
base:
gpu_id: '0'
algorithm: DB
- pretrained: True
+ pretrained: False
in_channels: [256, 512, 1024, 2048]
inner_channels: 256
k: 50
adaptive: True
crop_shape: [640,640]
shrink_ratio: 0.4
- n_epoch: 1201
- start_val: 400
+ n_epoch: 600
+ start_val: 6000
show_step: 20
checkpoints: ./checkpoint
save_epoch: 100
@@ -17,11 +17,11 @@ base:
restore_file : ./DB.pth.tar
backbone:
- function: ptocr.model.backbone.det_resnet_3*3,resnet50
+ function: ptocr.model.backbone.det_resnet_3_3,resnet50
head:
- function: ptocr.model.head.det_DBHead,DB_Head
-# function: ptocr.model.head.det_FPEM_FFM_Head,FPEM_FFM_Head
+# function: ptocr.model.head.det_DBHead,DB_Head
+ function: ptocr.model.head.det_FPEM_FFM_Head,FPEM_FFM_Head
# function: ptocr.model.head.det_FPNHead,FPN_Head
segout:
@@ -59,7 +59,7 @@ optimizer_decay:
trainload:
function: ptocr.dataloader.DetLoad.DBProcess,DBProcessTrain
- train_file: /src/notebooks/detect_text/icdar2015/train_list.txt
+ train_file: /src/notebooks/MyworkData/huayandang/train_list.txt
num_workers: 10
batch_size: 8
@@ -75,10 +75,10 @@ testload:
postprocess:
function: ptocr.postprocess.DBpostprocess,DBPostProcess
is_poly: False
- thresh: 0.5
- box_thresh: 0.6
+ thresh: 0.2
+ box_thresh: 0.3
max_candidates: 1000
- unclip_ratio: 2
+ unclip_ratio: 1.5
min_size: 3
infer:
diff --git a/config/det_DB_resnet50_mul.yaml b/config/det_DB_resnet50_mul.yaml
new file mode 100644
index 0000000..e37dcc4
--- /dev/null
+++ b/config/det_DB_resnet50_mul.yaml
@@ -0,0 +1,89 @@
+base:
+  gpu_id: '1' # GPU id for training; for multi-GPU training set e.g. '0,1,2'
+  algorithm: DB # algorithm name
+  pretrained: True # whether to load pretrained weights
+  in_channels: [256, 512, 1024, 2048] # output channels of the backbone stages
+  inner_channels: 256 # internal channel width of the head
+  k: 50
+  n_class: 3
+  adaptive: True
+  crop_shape: [640,640] # crop size used during training
+  shrink_ratio: 0.4 # inward shrink ratio of the text kernel
+  n_epoch: 1200 # number of training epochs
+  start_val: 400 # epoch at which validation starts; set it above n_epoch to skip validation
+  show_step: 20 # log the loss every N iterations
+  checkpoints: ./checkpoint # directory for saving models
+  save_epoch: 100 # save a model every N epochs
+  restore: False # whether to resume training
+  restore_file : ./DB.pth.tar # model path to load when resuming
+
+backbone:
+ function: ptocr.model.backbone.det_resnet,resnet50
+
+head:
+ function: ptocr.model.head.det_DBHead,DB_Head
+# function: ptocr.model.head.det_FPEM_FFM_Head,FPEM_FFM_Head
+# function: ptocr.model.head.det_FPNHead,FPN_Head
+
+segout:
+ function: ptocr.model.segout.det_DB_segout,SegDetectorMul
+
+architectures:
+ model_function: ptocr.model.architectures.det_model,DetModel
+ loss_function: ptocr.model.architectures.det_model,DetLoss
+
+loss:
+ function: ptocr.model.loss.db_loss,DBLossMul
+ l1_scale: 10
+ bce_scale: 1
+ class_scale: 1
+
+#optimizer:
+# function: ptocr.optimizer,AdamDecay
+# base_lr: 0.002
+# beta1: 0.9
+# beta2: 0.999
+
+optimizer:
+ function: ptocr.optimizer,SGDDecay
+ base_lr: 0.002
+ momentum: 0.99
+ weight_decay: 0.0005
+
+optimizer_decay:
+ function: ptocr.optimizer,adjust_learning_rate_poly
+ factor: 0.9
+
+#optimizer_decay:
+# function: ptocr.optimizer,adjust_learning_rate
+# schedule: [1,2]
+# gama: 0.1
+
+trainload:
+ function: ptocr.dataloader.DetLoad.DBProcess,DBProcessTrainMul
+ train_file: /src/notebooks/fangxuwei_96/TextGenerator-master/output/train/train_list.txt
+ num_workers: 10
+ batch_size: 8
+
+testload:
+ function: ptocr.dataloader.DetLoad.DBProcess,DBProcessTest
+ test_file: /src/notebooks/detect_text/icdar2015/test_list.txt
+ test_gt_path: /src/notebooks/detect_text/icdar2015/ch4_test_gts/
+ test_size: 736
+ stride: 32
+ num_workers: 5
+ batch_size: 4
+
+postprocess:
+ function: ptocr.postprocess.DBpostprocess,DBPostProcessMul
+  is_poly: False # at test time set True to output polygons for curved text; otherwise rectangular boxes are output
+ thresh: 0.5
+ box_thresh: 0.6
+ max_candidates: 1000
+ unclip_ratio: 2
+ min_size: 3
+
+infer:
+ model_path: './checkpoint/ag_DB_bb_resnet50_he_DB_Head_bs_8_ep_601_train_mul/DB_best.pth.tar'
+ path: '/src/notebooks/fangxuwei_96/TextGenerator-master/output/img/'
+ save_path: './result'
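Across these yaml files, entries like `function: ptocr.model.backbone.det_resnet,resnet50` name a Python module and an attribute, resolved at runtime by `create_module` in ptocr.utils.util_function. A minimal sketch of what such a resolver looks like (an assumption about its behavior, not the repo's exact code):

```python
import importlib

def create_module(function_str):
    """Resolve a 'package.module,attribute' config entry to a Python object."""
    module_path, attr_name = function_str.split(',')
    return getattr(importlib.import_module(module_path), attr_name)

# e.g. backbone = create_module('ptocr.model.backbone.det_resnet,resnet50')
```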
diff --git a/config/det_SAST_resnet50_3_3_ori_dataload.yaml b/config/det_SAST_resnet50_3_3_ori_dataload.yaml
index 89c4660..a784869 100644
--- a/config/det_SAST_resnet50_3_3_ori_dataload.yaml
+++ b/config/det_SAST_resnet50_3_3_ori_dataload.yaml
@@ -5,7 +5,7 @@ base:
with_attention: True
crop_shape: [512,512]
n_epoch: 901
- start_val: 500
+ start_val: 5000
show_step: 20
checkpoints: ./checkpoint
save_epoch: 100
@@ -13,7 +13,7 @@ base:
restore_file : ./checkpoint/ag_SAST_bb_resnet50_he_SASTHead_bs_12_ep_2000/SAST_best.pth.tar
backbone:
- function: ptocr.model.backbone.det_resnet_sast_3*3,resnet50
+ function: ptocr.model.backbone.det_resnet_sast_3_3,resnet50
head:
function: ptocr.model.head.det_SASTHead,SASTHead
@@ -67,7 +67,7 @@ optimizer_decay:
trainload:
function: ptocr.dataloader.DetLoad.SASTProcess_ori,SASTProcessTrain
- train_file: /src/notebooks/detect_text/icdar2015/train_list.txt
+ train_file: /src/notebooks/MyworkData/huayandang/train_list.txt
num_workers: 12
batch_size: 8
min_crop_side_ratio: 0.3
@@ -95,6 +95,6 @@ postprocess:
tcl_map_thresh: 0.7
infer:
- model_path: './checkpoint/ag_SAST_bb_resnet50_he_SASTHead_bs_8_ep_1000/SAST_best.pth.tar'
- path: '/src/notebooks/detect_text/icdar2015/ch4_test_images'
+ model_path: './checkpoint/ag_SAST_bb_resnet50_he_SASTHead_bs_8_ep_901/SAST_400.pth.tar'
+ path: '/src/notebooks/MyworkData/huayandang/train'
save_path: './result'
diff --git a/config/rec_CRNN_mobilev3.yaml b/config/rec_CRNN_mobilev3.yaml
deleted file mode 100644
index 3ce7a17..0000000
--- a/config/rec_CRNN_mobilev3.yaml
+++ /dev/null
@@ -1,78 +0,0 @@
-base:
- gpu_id: '0'
- algorithm: CRNN
- pretrained: True
- inchannel: 96
- hiddenchannel: 48
- img_shape: [32,280]
- is_gray: True
- use_conv: False
- use_attention: False
- use_lstm: False
- lstm_num: 2
- classes: 1000
- n_epoch: 20
- start_val: 1
- show_step: 20
- checkpoints: ./checkpoint
- save_epoch: 1
- show_num: 10
- restore: False
- restore_file : ./DB.pth.tar
-
-backbone:
- function: ptocr.model.backbone.reg_mobilev3,mobilenet_v3_small
-
-head:
- function: ptocr.model.head.rec_CRNNHead,CRNN_Head
-
-architectures:
- model_function: ptocr.model.architectures.rec_model,RecModel
- loss_function: ptocr.model.architectures.rec_model,RecLoss
-
-loss:
- function: ptocr.model.loss.ctc_loss,CTCLoss
- reduction: 'mean'
-
-optimizer:
- function: ptocr.optimizer,AdamDecay
- base_lr: 0.001
- beta1: 0.9
- beta2: 0.999
- weight_decay: 0.00005
-
-# optimizer:
-# function: ptocr.optimizer,SGDDecay
-# base_lr: 0.002
-# momentum: 0.99
-# weight_decay: 0.00005
-
-# optimizer_decay:
-# function: ptocr.optimizer,adjust_learning_rate_poly
-# factor: 0.9
-
-optimizer_decay:
- function: ptocr.optimizer,adjust_learning_rate
- schedule: [5,8,11,14]
- gama: 0.1
-
-trainload:
- function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessTrain
- train_file: /src/notebooks/fangxuwei_96/crnn-master/train_data/data/train_file/train.txt
- key_file: /src/notebooks/fangxuwei_96/crnn-master/train_data/data/train_file/key.txt
- num_workers: 10
- batch_size: 256
-
-testload:
- function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessTest
- test_file: /src/notebooks/fangxuwei_96/crnn-master/train_data/data/train_file/val.txt
- num_workers: 5
- batch_size: 256
-
-label_transform:
- function: ptocr.utils.transform_label,strLabelConverter
-
-infer:
- model_path: ''
- path: ''
- save_path: ''
diff --git a/config/rec_CRNN_mobilev3_large_english_all.yaml b/config/rec_CRNN_mobilev3_large_english_all.yaml
new file mode 100644
index 0000000..7cab74a
--- /dev/null
+++ b/config/rec_CRNN_mobilev3_large_english_all.yaml
@@ -0,0 +1,102 @@
+base:
+ gpu_id: '0,1'
+ algorithm: CRNN
+ pretrained: False
+ inchannel: 960
+ hiddenchannel: 96
+ img_shape: [32,100]
+ is_gray: True
+ use_conv: False
+ use_attention: False
+ use_lstm: True
+ lstm_num: 2
+ classes: 1000
+ max_iters: 300000
+ eval_iter: 10000
+ show_step: 100
+ checkpoints: ./checkpoint
+ save_epoch: 1
+ show_num: 10
+ restore: False
+ finetune: False
+ restore_file : ./checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_20_20210207English/CRNN_best.pth.tar
+
+backbone:
+ function: ptocr.model.backbone.rec_mobilev3_bd,mobilenet_v3_large
+
+head:
+ function: ptocr.model.head.rec_CRNNHead,CRNN_Head
+
+architectures:
+ model_function: ptocr.model.architectures.rec_model,RecModel
+ loss_function: ptocr.model.architectures.rec_model,RecLoss
+
+loss:
+ function: ptocr.model.loss.ctc_loss,CTCLoss
+ use_ctc_weight: False
+ reduction: 'mean'
+ center_function: ptocr.model.loss.centerloss,CenterLoss
+ use_center: False
+ center_lr: 0.5
+ label_score: 0.95
+# min_score: 0.01
+ weight_center: 0.000001
+
+
+optimizer:
+ function: ptocr.optimizer,AdamDecay
+ base_lr: 0.001
+ beta1: 0.9
+ beta2: 0.999
+ weight_decay: 0.00005
+
+# optimizer:
+# function: ptocr.optimizer,SGDDecay
+# base_lr: 0.002
+# momentum: 0.99
+# weight_decay: 0.00005
+
+# optimizer_decay:
+# function: ptocr.optimizer,adjust_learning_rate_poly
+# factor: 0.9
+
+optimizer_decay:
+ function: ptocr.optimizer,adjust_learning_rate
+ schedule: [100000,200000]
+ gama: 0.1
+
+optimizer_decay_center:
+ function: ptocr.optimizer,adjust_learning_rate_center
+ schedule: [100000,200000]
+ gama: 0.1
+
+trainload:
+ function: ptocr.dataloader.RecLoad.CRNNProcess1,GetDataLoad
+ train_file: ['/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/','/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/MJSynth']
+ batch_ratio: [0.5,0.5]
+ key_file: /src/notebooks/MyworkData/EnglishCrnnData/key_new.txt
+ bg_path: ./bg_img/
+ num_workers: 16
+ batch_size: 512
+
+testload:
+ function: ptocr.dataloader.RecLoad.CRNNProcess1,CRNNProcessTest
+ test_file: /src/notebooks/MyworkData/EnglishCrnnData/val_new.txt
+ num_workers: 8
+ batch_size: 256
+
+
+label_transform:
+ function: ptocr.utils.transform_label,strLabelConverter
+
+transform:
+ function: ptocr.dataloader.RecLoad.DataAgument,transform_label
+ t_type: lower
+ char_type: En
+
+infer:
+# model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar'
+ model_path: './checkpoint/ag_CRNN_bb_mobilenet_v3_large_he_CRNN_Head_bs_512_ep_300000_mobilev2_alldata/CRNN_210000.pth.tar'
+# path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg'
+ path: './english_val_img/'
+ save_path: ''
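The trainload block above mixes two LMDB sources with `batch_ratio: [0.5,0.5]`. A hedged sketch of how such ratio-weighted batch assembly can work (an assumption; the real logic lives in `ptocr.dataloader.RecLoad.CRNNProcess1,GetDataLoad`):

```python
import numpy as np

def draw_mixed_batch(datasets, batch_ratio, batch_size):
    """Assemble one batch in which batch_ratio[i] of the samples come from datasets[i]."""
    batch = []
    for ds, ratio in zip(datasets, batch_ratio):
        n = int(batch_size * ratio)  # 256 from SynthText and 256 from MJSynth at 512 / [0.5,0.5]
        for idx in np.random.randint(0, len(ds), size=n):
            batch.append(ds[idx])
    return batch
```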
diff --git a/config/rec_CRNN_mobilev3_large_english_lmdb.yaml b/config/rec_CRNN_mobilev3_large_english_lmdb.yaml
new file mode 100644
index 0000000..9479d4a
--- /dev/null
+++ b/config/rec_CRNN_mobilev3_large_english_lmdb.yaml
@@ -0,0 +1,77 @@
+base:
+ gpu_id: '0'
+ algorithm: CRNN
+ pretrained: False
+ inchannel: 960
+ hiddenchannel: 96
+ img_shape: [32,100]
+ is_gray: True
+ use_attention: False
+ use_lstm: True
+ lstm_num: 2
+ n_epoch: 8
+ start_val: 0
+ show_step: 50
+ checkpoints: ./checkpoint
+ save_epoch: 1
+ show_num: 10
+ restore: False
+ finetune: False
+ restore_file : ./checkpoint/
+
+backbone:
+ function: ptocr.model.backbone.rec_mobilev3_bd,mobilenet_v3_large
+
+head:
+ function: ptocr.model.head.rec_CRNNHead,CRNN_Head
+
+architectures:
+ model_function: ptocr.model.architectures.rec_model,RecModel
+ loss_function: ptocr.model.architectures.rec_model,RecLoss
+
+loss:
+ function: ptocr.model.loss.ctc_loss,CTCLoss
+ ctc_type: 'warpctc' # torchctc
+ use_ctc_weight: False
+ loss_title: ['ctc_loss']
+
+optimizer:
+ function: ptocr.optimizer,AdamDecay
+ base_lr: 0.001
+ beta1: 0.9
+ beta2: 0.999
+ weight_decay: 0.00005
+
+
+optimizer_decay:
+ function: ptocr.optimizer,adjust_learning_rate
+ schedule: [4,6]
+ gama: 0.1
+
+
+trainload:
+ function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessLmdbLoad
+ train_file: '/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/'
+ key_file: /src/notebooks/MyworkData/EnglishCrnnData/key_new.txt
+ bg_path: ./bg_img/
+ num_workers: 10
+ batch_size: 512
+
+valload:
+ function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessLmdbLoad
+ val_file: '/src/notebooks/IIIT5k_3000/lmdb/'
+ num_workers: 5
+ batch_size: 256
+
+label_transform:
+ function: ptocr.utils.transform_label,strLabelConverter
+ label_function: ptocr.dataloader.RecLoad.DataAgument,transform_label
+ t_type: lower
+ char_type: En
+
+infer:
+# model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar'
+ model_path: './checkpoint/ag_CRNN_bb_resnet34_he_CRNN_Head_bs_512_ep_8_center_loss/CRNN_best.pth.tar'
+# path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg'
+ path: './english_val_img/SVT/image/'
+ save_path: ''
diff --git a/config/rec_CRNN_mobilev3_small_english_all.yaml b/config/rec_CRNN_mobilev3_small_english_all.yaml
new file mode 100644
index 0000000..4ec984f
--- /dev/null
+++ b/config/rec_CRNN_mobilev3_small_english_all.yaml
@@ -0,0 +1,104 @@
+base:
+ gpu_id: '1'
+ algorithm: CRNN
+ pretrained: False
+ inchannel: 576
+ hiddenchannel: 48
+ img_shape: [32,100]
+ is_gray: True
+ use_conv: False
+ use_attention: False
+ use_lstm: True
+ lstm_num: 2
+ classes: 1000
+ max_iters: 300000
+ eval_iter: 10000
+ show_step: 100
+ checkpoints: ./checkpoint
+ save_epoch: 1
+ show_num: 10
+ restore: False
+ finetune: False
+ restore_file : ./checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_20_20210207English/CRNN_best.pth.tar
+
+backbone:
+ function: ptocr.model.backbone.rec_mobilev3_bd,mobilenet_v3_small
+
+head:
+ function: ptocr.model.head.rec_CRNNHead,CRNN_Head
+
+architectures:
+ model_function: ptocr.model.architectures.rec_model,RecModel
+ loss_function: ptocr.model.architectures.rec_model,RecLoss
+
+loss:
+ function: ptocr.model.loss.ctc_loss,CTCLoss
+ use_ctc_weight: False
+ reduction: 'mean'
+ center_function: ptocr.model.loss.centerloss,CenterLoss
+ use_center: False
+ center_lr: 0.5
+ label_score: 0.95
+# min_score: 0.01
+ weight_center: 0.000001
+
+
+optimizer:
+ function: ptocr.optimizer,AdamDecay
+ base_lr: 0.001
+ beta1: 0.9
+ beta2: 0.999
+ weight_decay: 0.00005
+
+# optimizer:
+# function: ptocr.optimizer,SGDDecay
+# base_lr: 0.002
+# momentum: 0.99
+# weight_decay: 0.00005
+
+# optimizer_decay:
+# function: ptocr.optimizer,adjust_learning_rate_poly
+# factor: 0.9
+
+optimizer_decay:
+ function: ptocr.optimizer,adjust_learning_rate
+ schedule: [100000,200000]
+ gama: 0.1
+
+optimizer_decay_center:
+ function: ptocr.optimizer,adjust_learning_rate_center
+ schedule: [100000,200000]
+ gama: 0.1
+
+trainload:
+ function: ptocr.dataloader.RecLoad.CRNNProcess1,GetDataLoad
+ train_file: ['/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/','/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/MJSynth']
+ batch_ratio: [0.5,0.5]
+ key_file: /src/notebooks/MyworkData/EnglishCrnnData/key_new.txt
+ bg_path: ./bg_img/
+ num_workers: 16
+ batch_size: 512
+
+valload:
+ function: ptocr.dataloader.RecLoad.CRNNProcess1,GetValDataLoad
+ root: '/src/notebooks/pytorchOCR-master/english_val_img'
+ dir: ['CUTE80','IC03_867','IC13_1015','IC13_857','IC15_1811','IIIT5k_3000','SVT','SVTP','IC15_2077']
+ test_file: /src/notebooks/MyworkData/EnglishCrnnData/val_new.txt
+ num_workers: 2
+ batch_size: 1
+
+
+label_transform:
+ function: ptocr.utils.transform_label,strLabelConverter
+
+transform:
+ function: ptocr.dataloader.RecLoad.DataAgument,transform_label
+ t_type: lower
+ char_type: En
+
+infer:
+# model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar'
+ model_path: './checkpoint/ag_CRNN_bb_mobilenet_v3_small_he_CRNN_Head_bs_512_ep_300000_mobilev2_small_alldata/CRNN_210000.pth.tar'
+# path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg'
+ path: './english_val_img/'
+ save_path: ''
diff --git a/config/rec_CRNN_mobilev3_small_english_lmdb.yaml b/config/rec_CRNN_mobilev3_small_english_lmdb.yaml
new file mode 100644
index 0000000..7cf838b
--- /dev/null
+++ b/config/rec_CRNN_mobilev3_small_english_lmdb.yaml
@@ -0,0 +1,77 @@
+base:
+ gpu_id: '0'
+ algorithm: CRNN
+ pretrained: False
+ inchannel: 576
+ hiddenchannel: 48
+ img_shape: [32,100]
+ is_gray: True
+ use_attention: False
+ use_lstm: True
+ lstm_num: 2
+ n_epoch: 8
+ start_val: 0
+ show_step: 50
+ checkpoints: ./checkpoint
+ save_epoch: 1
+ show_num: 10
+ restore: False
+ finetune: False
+ restore_file : ./checkpoint/
+
+backbone:
+ function: ptocr.model.backbone.rec_mobilev3_bd,mobilenet_v3_small
+
+head:
+ function: ptocr.model.head.rec_CRNNHead,CRNN_Head
+
+architectures:
+ model_function: ptocr.model.architectures.rec_model,RecModel
+ loss_function: ptocr.model.architectures.rec_model,RecLoss
+
+loss:
+ function: ptocr.model.loss.ctc_loss,CTCLoss
+ ctc_type: 'warpctc' # torchctc
+ use_ctc_weight: False
+ loss_title: ['ctc_loss']
+
+optimizer:
+ function: ptocr.optimizer,AdamDecay
+ base_lr: 0.001
+ beta1: 0.9
+ beta2: 0.999
+ weight_decay: 0.00005
+
+
+optimizer_decay:
+ function: ptocr.optimizer,adjust_learning_rate
+ schedule: [4,6]
+ gama: 0.1
+
+
+trainload:
+ function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessLmdbLoad
+ train_file: '/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/'
+ key_file: /src/notebooks/MyworkData/EnglishCrnnData/key_new.txt
+ bg_path: ./bg_img/
+ num_workers: 10
+ batch_size: 512
+
+valload:
+ function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessLmdbLoad
+ val_file: '/src/notebooks/IIIT5k_3000/lmdb/'
+ num_workers: 5
+ batch_size: 256
+
+label_transform:
+ function: ptocr.utils.transform_label,strLabelConverter
+ label_function: ptocr.dataloader.RecLoad.DataAgument,transform_label
+ t_type: lower
+ char_type: En
+
+infer:
+# model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar'
+ model_path: './checkpoint/ag_CRNN_bb_resnet34_he_CRNN_Head_bs_512_ep_8_center_loss/CRNN_best.pth.tar'
+# path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg'
+ path: './english_val_img/SVT/image/'
+ save_path: ''
diff --git a/config/rec_CRNN_resnet34_english_lmdb.yaml b/config/rec_CRNN_resnet34_english_lmdb.yaml
new file mode 100644
index 0000000..f5da317
--- /dev/null
+++ b/config/rec_CRNN_resnet34_english_lmdb.yaml
@@ -0,0 +1,77 @@
+base:
+ gpu_id: '0'
+ algorithm: CRNN
+ pretrained: False
+ inchannel: 512
+ hiddenchannel: 128
+ img_shape: [32,100]
+ is_gray: True
+ use_attention: False
+ use_lstm: True
+ lstm_num: 2
+ n_epoch: 8
+ start_val: 0
+ show_step: 50
+ checkpoints: ./checkpoint
+ save_epoch: 1
+ show_num: 10
+ restore: True
+ finetune: False
+ restore_file : ./checkpoint/ag_CRNN_bb_resnet34_he_CRNN_Head_bs_512_ep_8_no_attention_no_weight/CRNN_best.pth.tar
+
+backbone:
+ function: ptocr.model.backbone.reg_resnet_bd,resnet34
+
+head:
+ function: ptocr.model.head.rec_CRNNHead,CRNN_Head
+
+architectures:
+ model_function: ptocr.model.architectures.rec_model,RecModel
+ loss_function: ptocr.model.architectures.rec_model,RecLoss
+
+loss:
+ function: ptocr.model.loss.ctc_loss,CTCLoss
+ ctc_type: 'warpctc' # torchctc
+ use_ctc_weight: False
+ loss_title: ['ctc_loss']
+
+optimizer:
+ function: ptocr.optimizer,AdamDecay
+ base_lr: 0.001
+ beta1: 0.9
+ beta2: 0.999
+ weight_decay: 0.00005
+
+
+optimizer_decay:
+ function: ptocr.optimizer,adjust_learning_rate
+ schedule: [4,6]
+ gama: 0.1
+
+
+trainload:
+ function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessLmdbLoad
+ train_file: '/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/'
+ key_file: /src/notebooks/MyworkData/EnglishCrnnData/key_new.txt
+ bg_path: ./bg_img/
+ num_workers: 16
+ batch_size: 512
+
+valload:
+ function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessLmdbLoad
+ val_file: '/src/notebooks/IIIT5k_3000/lmdb/'
+ num_workers: 5
+ batch_size: 256
+
+label_transform:
+ function: ptocr.utils.transform_label,strLabelConverter
+ label_function: ptocr.dataloader.RecLoad.DataAgument,transform_label
+ t_type: lower
+ char_type: En
+
+infer:
+# model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar'
+ model_path: './checkpoint/ag_CRNN_bb_resnet34_he_CRNN_Head_bs_512_ep_8_center_loss/CRNN_best.pth.tar'
+# path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg'
+ path: './english_val_img/SVT/image/'
+ save_path: ''
diff --git a/config/rec_CRNN_ori.yaml b/config/rec_CRNN_resnet_english.yaml
similarity index 59%
rename from config/rec_CRNN_ori.yaml
rename to config/rec_CRNN_resnet_english.yaml
index c2264fe..fc51a28 100644
--- a/config/rec_CRNN_ori.yaml
+++ b/config/rec_CRNN_resnet_english.yaml
@@ -1,28 +1,28 @@
base:
- gpu_id: '1'
+ gpu_id: '0,1'
algorithm: CRNN
pretrained: False
inchannel: 512
hiddenchannel: 128
- img_shape: [32,200]
+ img_shape: [32,100]
is_gray: True
use_conv: False
use_attention: False
- use_lstm: False
- lstm_num: 1
+ use_lstm: True
+ lstm_num: 2
classes: 1000
- n_epoch: 20
+ n_epoch: 8
start_val: 0
- show_step: 20
+ show_step: 100
checkpoints: ./checkpoint
save_epoch: 1
show_num: 10
- restore: False
- finetune: False
- restore_file : ./checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_20_test_center_3/CRNN_best_ori.pth.tar
+ restore: True
+ finetune: True
+ restore_file : ./checkpoint/ag_CRNN_bb_resnet34_he_CRNN_Head_bs_256_ep_20_no_channel_timestep_rnn/CRNN_best.pth.tar
backbone:
- function: ptocr.model.backbone.rec_crnn_backbone,rec_crnn_backbone
+ function: ptocr.model.backbone.reg_resnet_bd,resnet34
head:
function: ptocr.model.head.rec_CRNNHead,CRNN_Head
@@ -33,13 +33,14 @@ architectures:
loss:
function: ptocr.model.loss.ctc_loss,CTCLoss
- reduction: 'sum'
+ use_ctc_weight: True
+ reduction: 'none'
center_function: ptocr.model.loss.centerloss,CenterLoss
- use_center: False
+ use_center: True
center_lr: 0.5
label_score: 0.95
# min_score: 0.01
- weight_center: 0.000001
+ weight_center: 0.001
optimizer:
@@ -61,24 +62,25 @@ optimizer:
optimizer_decay:
function: ptocr.optimizer,adjust_learning_rate
- schedule: [5,10,15]
+ schedule: [4,6]
gama: 0.1
optimizer_decay_center:
function: ptocr.optimizer,adjust_learning_rate_center
- schedule: [6,10,15]
+ schedule: [4,6]
gama: 0.1
trainload:
- function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessTrain
- train_file: /src/notebooks/fangxuwei_96/crnn-master/train_data/data/train_center/train.txt
- key_file: /src/notebooks/fangxuwei_96/crnn-master/train_data/data/train_center/key.txt
+ function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessTrainLmdb
+ train_file: '/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/'
+ key_file: /src/notebooks/MyworkData/EnglishCrnnData/key_new.txt
+ bg_path: ./bg_img/
num_workers: 10
- batch_size: 256
+ batch_size: 512
testload:
function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessTest
- test_file: /src/notebooks/fangxuwei_96/crnn-master/train_data/data/train_center/val.txt
+ test_file: /src/notebooks/MyworkData/EnglishCrnnData/val_new.txt
num_workers: 5
batch_size: 256
@@ -87,7 +89,8 @@ label_transform:
function: ptocr.utils.transform_label,strLabelConverter
infer:
- model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_20_test_center_3/CRNN_best.pth.tar'
-# model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_20_no_center/CRNN_best.pth.tar'
- path: '/src/notebooks/fangxuwei_96/crnn-master/train_data/data/gen_data_sim/center_test/default'
+# model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar'
+ model_path: './checkpoint/ag_CRNN_bb_resnet34_he_CRNN_Head_bs_512_ep_8_center_loss/CRNN_best.pth.tar'
+# path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg'
+ path: './english_val_img/SVT/image/'
save_path: ''
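This config turns on center loss jointly with CTC (use_center: True, weight_center: 0.001, with its own learning rate center_lr and decay schedule). A minimal sketch of the center-loss idea, as an assumption about what `ptocr.model.loss.centerloss,CenterLoss` computes:

```python
import torch
import torch.nn as nn

class CenterLossSketch(nn.Module):
    """Pull each feature vector toward a learned center for its class."""
    def __init__(self, num_classes, feat_dim):
        super().__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, feats, labels):
        # feats: (N, feat_dim), labels: (N,) class ids; mean squared distance to centers
        return ((feats - self.centers[labels]) ** 2).sum(dim=1).mean()
```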
diff --git a/config/rec_CRNN_resnet_english_all.yaml b/config/rec_CRNN_resnet_english_all.yaml
new file mode 100644
index 0000000..b0e3771
--- /dev/null
+++ b/config/rec_CRNN_resnet_english_all.yaml
@@ -0,0 +1,102 @@
+base:
+ gpu_id: '0,1'
+ algorithm: CRNN
+ pretrained: False
+ inchannel: 512
+ hiddenchannel: 256
+ img_shape: [32,100]
+ is_gray: True
+ use_conv: False
+ use_attention: False
+ use_lstm: True
+ lstm_num: 2
+ classes: 1000
+ max_iters: 200000
+ eval_iter: 10000
+ show_step: 100
+ checkpoints: ./checkpoint
+ save_epoch: 1
+ show_num: 10
+ restore: False
+ finetune: False
+ restore_file : ./checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_20_20210207English/CRNN_best.pth.tar
+
+backbone:
+ function: ptocr.model.backbone.reg_resnet_bd,resnet34
+
+head:
+ function: ptocr.model.head.rec_CRNNHead,CRNN_Head
+
+architectures:
+ model_function: ptocr.model.architectures.rec_model,RecModel
+ loss_function: ptocr.model.architectures.rec_model,RecLoss
+
+loss:
+ function: ptocr.model.loss.ctc_loss,CTCLoss
+ use_ctc_weight: False
+ reduction: 'mean'
+ center_function: ptocr.model.loss.centerloss,CenterLoss
+ use_center: False
+ center_lr: 0.5
+ label_score: 0.95
+# min_score: 0.01
+ weight_center: 0.000001
+
+
+optimizer:
+ function: ptocr.optimizer,AdamDecay
+ base_lr: 0.001
+ beta1: 0.9
+ beta2: 0.999
+ weight_decay: 0.00005
+
+# optimizer:
+# function: ptocr.optimizer,SGDDecay
+# base_lr: 0.002
+# momentum: 0.99
+# weight_decay: 0.00005
+
+# optimizer_decay:
+# function: ptocr.optimizer,adjust_learning_rate_poly
+# factor: 0.9
+
+optimizer_decay:
+ function: ptocr.optimizer,adjust_learning_rate
+ schedule: [80000,160000]
+ gama: 0.1
+
+optimizer_decay_center:
+ function: ptocr.optimizer,adjust_learning_rate_center
+ schedule: [80000,160000]
+ gama: 0.1
+
+trainload:
+ function: ptocr.dataloader.RecLoad.CRNNProcess1,GetDataLoad
+ train_file: ['/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/','/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/MJSynth']
+ batch_ratio: [0.5,0.5]
+ key_file: /src/notebooks/MyworkData/EnglishCrnnData/key_new.txt
+ bg_path: ./bg_img/
+ num_workers: 16
+ batch_size: 512
+
+testload:
+ function: ptocr.dataloader.RecLoad.CRNNProcess1,CRNNProcessTest
+ test_file: /src/notebooks/MyworkData/EnglishCrnnData/val_new.txt
+ num_workers: 8
+ batch_size: 256
+
+
+label_transform:
+ function: ptocr.utils.transform_label,strLabelConverter
+
+transform:
+ function: ptocr.dataloader.RecLoad.DataAgument,transform_label
+ t_type: lower
+ char_type: En
+
+infer:
+# model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar'
+ model_path: './checkpoint/ag_CRNN_bb_resnet34_he_CRNN_Head_bs_512_ep_200000_alldata/CRNN_120000.pth.tar'
+# path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg'
+ path: './english_val_img/'
+ save_path: ''
diff --git a/config/rec_CRNN_vgg16_bn.yaml b/config/rec_CRNN_vgg16_bn.yaml
deleted file mode 100644
index e2f4747..0000000
--- a/config/rec_CRNN_vgg16_bn.yaml
+++ /dev/null
@@ -1,78 +0,0 @@
-base:
- gpu_id: '1'
- algorithm: CRNN
- pretrained: True
- inchannel: 512
- hiddenchannel: 128
- img_shape: [32,100]
- is_gray: True
- use_conv: False
- use_attention: False
- use_lstm: False
- lstm_num: 2
- classes: 1000
- n_epoch: 100
- start_val: 10
- show_step: 20
- checkpoints: ./checkpoint
- save_epoch: 1
- show_num: 10
- restore: False
- restore_file : ./DB.pth.tar
-
-backbone:
- function: ptocr.model.backbone.rec_vgg,vgg16_bn
-
-head:
- function: ptocr.model.head.rec_CRNNHead,CRNN_Head
-
-architectures:
- model_function: ptocr.model.architectures.rec_model,RecModel
- loss_function: ptocr.model.architectures.rec_model,RecLoss
-
-loss:
- function: ptocr.model.loss.ctc_loss,CTCLoss
- reduction: 'mean'
-
-optimizer:
- function: ptocr.optimizer,AdamDecay
- base_lr: 0.001
- beta1: 0.9
- beta2: 0.999
- weight_decay: 0.00005
-
-# optimizer:
-# function: ptocr.optimizer,SGDDecay
-# base_lr: 0.002
-# momentum: 0.99
-# weight_decay: 0.00005
-
-optimizer_decay:
- function: ptocr.optimizer,adjust_learning_rate_poly
- factor: 0.9
-
-#optimizer_decay:
-# function: ptocr.optimizer,adjust_learning_rate
-# schedule: [1,2]
-# gama: 0.1
-
-trainload:
- function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessTrain
- train_file: /src/notebooks/detect_text/icdar2015/recognize/train_list.txt
- key_file: /src/notebooks/detect_text/icdar2015/recognize/key.txt
- num_workers: 10
- batch_size: 32
-
-testload:
- function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessTest
- test_file: /src/notebooks/detect_text/icdar2015/recognize/test_list.txt
- num_workers: 5
- batch_size: 32
-
-label_transform:
- function: ptocr.utils.transform_label,strLabelConverter
-
-infer:
- model_path: ''
- path: ''
- save_path: ''
diff --git a/config/rec_FC_resnet_english_all.yaml b/config/rec_FC_resnet_english_all.yaml
new file mode 100644
index 0000000..2a769de
--- /dev/null
+++ b/config/rec_FC_resnet_english_all.yaml
@@ -0,0 +1,107 @@
+base:
+ gpu_id: '0,1'
+ algorithm: FC
+ pretrained: False
+ in_channels: 2048
+ out_channels: 1024
+ ignore_index: 37
+ max_length: 25
+ img_shape: [32,100]
+ is_gray: True
+ use_conv: False
+ use_attention: False
+ use_lstm: True
+ lstm_num: 2
+ num_class: 36
+ start_iters: 0
+ max_iters: 300000
+ eval_iter: 10000
+ show_step: 100
+ checkpoints: ./checkpoint
+ save_epoch: 1
+ show_num: 10
+ restore: False
+ finetune: False
+ restore_file : ./checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_20_20210207English/CRNN_best.pth.tar
+
+backbone:
+ function: ptocr.model.backbone.reg_resnet_bd,resnet50
+
+head:
+ function: ptocr.model.head.rec_FCHead,FC_Head
+
+architectures:
+ model_function: ptocr.model.architectures.rec_model,RecModel
+ loss_function: ptocr.model.architectures.rec_model,RecLoss
+
+loss:
+ function: ptocr.model.loss.fc_loss,FCLoss
+ use_ctc_weight: False
+ reduction: 'mean'
+ center_function: ptocr.model.loss.centerloss,CenterLoss
+ use_center: False
+ center_lr: 0.5
+ label_score: 0.95
+# min_score: 0.01
+ weight_center: 0.000001
+
+
+optimizer:
+ function: ptocr.optimizer,AdamDecay
+ base_lr: 0.001
+ beta1: 0.9
+ beta2: 0.999
+ weight_decay: 0.00005
+
+# optimizer:
+# function: ptocr.optimizer,SGDDecay
+# base_lr: 0.002
+# momentum: 0.99
+# weight_decay: 0.00005
+
+# optimizer_decay:
+# function: ptocr.optimizer,adjust_learning_rate_poly
+# factor: 0.9
+
+optimizer_decay:
+ function: ptocr.optimizer,adjust_learning_rate
+ schedule: [100000,200000]
+ gama: 0.1
+
+optimizer_decay_center:
+ function: ptocr.optimizer,adjust_learning_rate_center
+ schedule: [80000,160000]
+ gama: 0.1
+
+trainload:
+ function: ptocr.dataloader.RecLoad.CRNNProcess1,GetDataLoad
+ train_file: ['/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/','/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/MJSynth']
+ batch_ratio: [0.5,0.5]
+ key_file: /src/notebooks/MyworkData/EnglishCrnnData/key_new.txt
+ bg_path: ./bg_img/
+ num_workers: 16
+ batch_size: 256
+
+valload:
+ function: ptocr.dataloader.RecLoad.CRNNProcess1,GetValDataLoad
+ root: '/src/notebooks/pytorchOCR-master/english_val_img'
+ dir: ['CUTE80','IC03_867','IC13_1015','IC13_857','IC15_1811','IIIT5k_3000','SVT','SVTP','IC15_2077']
+ test_file: /src/notebooks/MyworkData/EnglishCrnnData/val_new.txt
+ num_workers: 2
+ batch_size: 1
+
+
+label_transform:
+ function: ptocr.utils.transform_label,FCConverter
+
+transform:
+ function: ptocr.dataloader.RecLoad.DataAgument,transform_label
+ t_type: lower
+ char_type: En
+
+infer:
+# model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar'
+ model_path: './checkpoint/ag_FC_bb_resnet34_he_FC_Head_bs_128_ep_200000_FC/FC_190000.pth.tar'
+# path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg'
+ path: './english_val_img/'
+ save_path: ''
diff --git "a/doc/md/\346\226\207\346\234\254\350\257\206\345\210\253\350\256\255\347\273\203\346\226\207\346\241\243.md" "b/doc/md/\346\226\207\346\234\254\350\257\206\345\210\253\350\256\255\347\273\203\346\226\207\346\241\243.md"
index 6a9ca23..da5a0fc 100644
--- "a/doc/md/\346\226\207\346\234\254\350\257\206\345\210\253\350\256\255\347\273\203\346\226\207\346\241\243.md"
+++ "b/doc/md/\346\226\207\346\234\254\350\257\206\345\210\253\350\256\255\347\273\203\346\226\207\346\241\243.md"
@@ -4,24 +4,16 @@
You need a train_list.txt [example](https://github.com/BADBADBADBOY/pytorchOCR/blob/master/doc/example/rec_train_list.txt), with the format: absolute image path + \t + label. See the examples under data/example in this project.
If you want to run validation during training, prepare a test_list.txt [example](https://github.com/BADBADBADBOY/pytorchOCR/blob/master/doc/example/rec_test_list.txt) in the same format.
-#### Standard training (using rec_CRNN_ori.yaml as an example)
-1. In the yaml, set restore and finetune under base to False and use_center under loss to False, then adjust the data paths and other parameters.
-2. Run the following command
+#### Training
+1. Edit the parameters in the corresponding algorithm's yaml under ./config; usually only the data paths need changing.
+2. At the bottom of ./tools/rec_train.py, switch to the config yaml of the algorithm you want to train.
+3. Run the following command
```
-python3 ./tools/rec_train.py --config ./config/rec_CRNN_ori.yaml --log_str log
-
-
-#### CenterLoss training (using rec_CRNN_ori.yaml as an example)
-1. In the yaml, set restore and finetune under base to True and use_center under loss to True, and point restore_file under base at the best model from standard training.
-2. Run the following command
-
-```
-python3 ./tools/rec_train.py --config ./config/rec_CRNN_ori.yaml --log_str log
+python3 ./tools/rec_train.py
```
#### Testing
-1. Assign the trained model path to model_path under infer in the yaml, and the image folder to path
-2. Run the following command
+1. Run the following command
```
python3 ./tools/rec_infer.py
diff --git a/doc/show/1.jpg b/doc/show/1.jpg
new file mode 100644
index 0000000..b8a957b
Binary files /dev/null and b/doc/show/1.jpg differ
diff --git a/doc/show/2.jpg b/doc/show/2.jpg
new file mode 100644
index 0000000..bebb5cf
Binary files /dev/null and b/doc/show/2.jpg differ
diff --git a/doc/show/3.jpg b/doc/show/3.jpg
new file mode 100644
index 0000000..3e80900
Binary files /dev/null and b/doc/show/3.jpg differ
diff --git a/finetune_prune_model.sh b/finetune_prune_model.sh
deleted file mode 100644
index 6728b83..0000000
--- a/finetune_prune_model.sh
+++ /dev/null
@@ -1 +0,0 @@
-python3 tools/det_train.py --config ./config/det_DB_mobilev3.yaml --log_str total_prune_20201015_distil3 --pruned_model_dict_path ./checkpoint/ag_DB_bb_mobilenet_v3_small_he_DB_Head_bs_16_ep_1200_mobile_slim_all/pruned/pruned_dict.dict --prune_model_path ./checkpoint/ag_DB_bb_mobilenet_v3_small_he_DB_Head_bs_16_ep_1200_mobile_slim_all/pruned/pruned_dict.pth --prune_type total --n_epoch 200 --start_val 30 --base_lr 0.0008 --gpu_id 2 --t_ratio 0.1 --t_model_path ./checkpoint/ag_DB_bb_resnet50_he_DB_Head_bs_8_ep_1201/DB_best.pth.tar --t_config ./config/det_DB_resnet50_3_3.yaml
\ No newline at end of file
diff --git a/infer.sh b/infer.sh
deleted file mode 100644
index c111f04..0000000
--- a/infer.sh
+++ /dev/null
@@ -1 +0,0 @@
-python3 ./tools/det_infer.py --config ./config/det_DB_mobilev3.yaml --model_path ./checkpoint/ag_DB_bb_mobilenet_v3_small_he_DB_Head_bs_16_ep_1200/DB_best.pth.tar --img_path /src/notebooks/detect_text/icdar2015/ch4_test_images --result_save_path ./result --onnx_path ./onnx/DBnet.onnx --trt_path ./onnx/DBnet_batch.engine --batch_size 2 --max_size 1536 --add_padding
\ No newline at end of file
diff --git "a/pre_model/\346\226\260\345\273\272\346\226\207\346\234\254\346\226\207\346\241\243.txt" "b/pre_model/\346\226\260\345\273\272\346\226\207\346\234\254\346\226\207\346\241\243.txt"
deleted file mode 100644
index e69de29..0000000
diff --git a/ptocr/__pycache__/__init__.cpython-36.pyc b/ptocr/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..04928da
Binary files /dev/null and b/ptocr/__pycache__/__init__.cpython-36.pyc differ
diff --git a/ptocr/__pycache__/optimizer.cpython-36.pyc b/ptocr/__pycache__/optimizer.cpython-36.pyc
new file mode 100644
index 0000000..c85c005
Binary files /dev/null and b/ptocr/__pycache__/optimizer.cpython-36.pyc differ
diff --git a/ptocr/dataloader/DetLoad/DBProcess.py b/ptocr/dataloader/DetLoad/DBProcess.py
index 7fbe42b..749baed 100644
--- a/ptocr/dataloader/DetLoad/DBProcess.py
+++ b/ptocr/dataloader/DetLoad/DBProcess.py
@@ -94,7 +94,88 @@ def __getitem__(self, index):
return img,gt,gt_mask,thresh_map,thresh_mask
+class DBProcessTrainMul(data.Dataset):
+ def __init__(self,config):
+ super(DBProcessTrainMul,self).__init__()
+ self.crop_shape = config['base']['crop_shape']
+ self.MBM = MakeBorderMap()
+ self.TSM = Random_Augment(self.crop_shape)
+ self.MSM = MakeSegMap(shrink_ratio = config['base']['shrink_ratio'])
+ img_list, label_list = self.get_base_information(config['trainload']['train_file'])
+ self.img_list = img_list
+ self.label_list = label_list
+
+ def order_points(self, pts):
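+        # Order the 4 corner points as [top-left, top-right, bottom-right, bottom-left]:
+        # TL has the smallest x+y sum, BR the largest; TR the smallest y-x diff, BL the largest.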
+ rect = np.zeros((4, 2), dtype="float32")
+ s = pts.sum(axis=1)
+ rect[0] = pts[np.argmin(s)]
+ rect[2] = pts[np.argmax(s)]
+ diff = np.diff(pts, axis=1)
+ rect[1] = pts[np.argmin(diff)]
+ rect[3] = pts[np.argmax(diff)]
+ return rect
+
+ def get_bboxes(self,gt_path):
+ polys = []
+ tags = []
+ classes = []
+ with open(gt_path, 'r', encoding='utf-8') as fid:
+ lines = fid.readlines()
+ for line in lines:
+ line = line.replace('\ufeff', '').replace('\xef\xbb\xbf', '')
+ gt = line.split(',')
+ if "#" in gt[-1]:
+ tags.append(True)
+ classes.append(-2)
+ else:
+ tags.append(False)
+ classes.append(int(gt[-1]))
+ # box = [int(gt[i]) for i in range(len(gt)//2*2)]
+ box = [int(gt[i]) for i in range(8)]
+ polys.append(box)
+ return np.array(polys), tags, classes
+
+ def get_base_information(self,train_txt_file):
+ label_list = []
+ img_list = []
+ with open(train_txt_file,'r',encoding='utf-8') as fid:
+ lines = fid.readlines()
+ for line in lines:
+ line = line.strip('\n').split('\t')
+ img_list.append(line[0])
+ result = self.get_bboxes(line[1])
+ label_list.append(result)
+ return img_list,label_list
+
+ def __len__(self):
+ return len(self.img_list)
+
+ def __getitem__(self, index):
+
+ img = Image.open(self.img_list[index]).convert('RGB')
+ img = np.array(img)[:,:,::-1]
+
+ polys, dontcare, classes = self.label_list[index]
+
+ img, polys = self.TSM.random_scale(img, polys, self.crop_shape[0])
+ img, polys = self.TSM.random_rotate(img, polys)
+ img, polys = self.TSM.random_flip(img, polys)
+ img, polys, classes,dontcare = self.TSM.random_crop_db_mul(img, polys,classes, dontcare)
+ img, gt, classes,gt_mask = self.MSM.process_mul(img, polys, classes,dontcare)
+ img, thresh_map, thresh_mask = self.MBM.process(img, polys, dontcare)
+
+ img = Image.fromarray(img).convert('RGB')
+ img = transforms.ColorJitter(brightness=32.0 / 255, saturation=0.5)(img)
+ img = self.TSM.normalize_img(img)
+
+
+ gt = torch.from_numpy(gt).float()
+ gt_classes = torch.from_numpy(classes).long()
+ gt_mask = torch.from_numpy(gt_mask).float()
+ thresh_map = torch.from_numpy(thresh_map).float()
+ thresh_mask = torch.from_numpy(thresh_mask).float()
+ return img,gt,gt_classes,gt_mask,thresh_map,thresh_mask
class DBProcessTest(data.Dataset):
def __init__(self,config):
diff --git a/ptocr/dataloader/DetLoad/MakeSegMap.py b/ptocr/dataloader/DetLoad/MakeSegMap.py
index 3a72c7d..02a8bff 100644
--- a/ptocr/dataloader/DetLoad/MakeSegMap.py
+++ b/ptocr/dataloader/DetLoad/MakeSegMap.py
@@ -27,6 +27,7 @@ def __init__(self, algorithm='DB',min_text_size = 8,shrink_ratio = 0.4,is_traini
self.shrink_ratio = shrink_ratio
self.is_training = is_training
self.algorithm = algorithm
+
def process(self, img,polys,dontcare):
'''
requied keys:
@@ -77,7 +78,61 @@ def process(self, img,polys,dontcare):
if self.algorithm == 'PAN':
return img,gt_text,gt_text_key,gt,gt_kernel_key,mask
return img,gt,mask
+
+ def process_mul(self, img,polys,classes,dontcare):
+ '''
+ requied keys:
+ image, polygons, ignore_tags, filename
+ adding keys:
+ mask
+ '''
+ h, w = img.shape[:2]
+ if self.is_training:
+ polys, dontcare = self.validate_polygons(
+ polys, dontcare, h, w)
+ gt = np.zeros((h, w), dtype=np.float32)
+ gt_class = np.zeros((h, w), dtype=np.float32)
+ mask = np.ones((h, w), dtype=np.float32)
+
+ if self.algorithm =='PAN':
+ gt_text = np.zeros((h, w), dtype=np.float32)
+ gt_text_key = np.zeros((h, w), dtype=np.float32)
+ gt_kernel_key = np.zeros((h, w), dtype=np.float32)
+ for i in range(len(polys)):
+ polygon = polys[i]
+ height = max(polygon[:, 1]) - min(polygon[:, 1])
+ width = max(polygon[:, 0]) - min(polygon[:, 0])
+ if dontcare[i] or min(height, width) < self.min_text_size:
+ cv2.fillPoly(mask, polygon.astype(
+ np.int32)[np.newaxis, :, :], 0)
+ dontcare[i] = True
+ else:
+ if self.algorithm == 'PAN':
+ cv2.fillPoly(gt_text, [polygon.astype(np.int32)], 1)
+ cv2.fillPoly(gt_text_key, [polygon.astype(np.int32)], i + 1)
+ polygon_shape = Polygon(polygon)
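+                # DB shrink offset: D = Area * (1 - r^2) / Perimeter, with r = shrink_ratio;
+                # e.g. a 100x100 box at r = 0.4 shrinks inward by 10000*(1-0.16)/400 = 21 px.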
+ distance = polygon_shape.area * \
+ (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
+ subject = [tuple(l) for l in polys[i]]
+ padding = pyclipper.PyclipperOffset()
+ padding.AddPath(subject, pyclipper.JT_ROUND,
+ pyclipper.ET_CLOSEDPOLYGON)
+ shrinked = padding.Execute(-distance)
+ if shrinked == []:
+ cv2.fillPoly(mask, polygon.astype(
+ np.int32)[np.newaxis, :, :], 0)
+ dontcare[i] = True
+ continue
+ shrinked = np.array(shrinked[0]).reshape(-1, 2)
+ cv2.fillPoly(gt, [shrinked.astype(np.int32)], 1)
+ cv2.fillPoly(gt_class, polygon.astype(np.int32)[np.newaxis, :, :], 1+classes[i])
+ if self.algorithm == 'PAN':
+ cv2.fillPoly(gt_kernel_key, [shrinked.astype(np.int32)], i + 1)
+ if self.algorithm == 'PAN':
+ return img,gt_text,gt_text_key,gt,gt_kernel_key,mask
+ return img,gt,gt_class,mask
+
def validate_polygons(self, polygons, ignore_tags, h, w):
'''
polygons (numpy.array, required): of shape (num_instances, num_points, 2)
diff --git a/ptocr/dataloader/DetLoad/__pycache__/DBProcess.cpython-36.pyc b/ptocr/dataloader/DetLoad/__pycache__/DBProcess.cpython-36.pyc
new file mode 100644
index 0000000..2035a23
Binary files /dev/null and b/ptocr/dataloader/DetLoad/__pycache__/DBProcess.cpython-36.pyc differ
diff --git a/ptocr/dataloader/DetLoad/__pycache__/MakeBorderMap.cpython-36.pyc b/ptocr/dataloader/DetLoad/__pycache__/MakeBorderMap.cpython-36.pyc
new file mode 100644
index 0000000..014c97a
Binary files /dev/null and b/ptocr/dataloader/DetLoad/__pycache__/MakeBorderMap.cpython-36.pyc differ
diff --git a/ptocr/dataloader/DetLoad/__pycache__/MakeSegMap.cpython-36.pyc b/ptocr/dataloader/DetLoad/__pycache__/MakeSegMap.cpython-36.pyc
new file mode 100644
index 0000000..aaafb5a
Binary files /dev/null and b/ptocr/dataloader/DetLoad/__pycache__/MakeSegMap.cpython-36.pyc differ
diff --git a/ptocr/dataloader/DetLoad/__pycache__/SASTProcess_ori.cpython-36.pyc b/ptocr/dataloader/DetLoad/__pycache__/SASTProcess_ori.cpython-36.pyc
new file mode 100644
index 0000000..93c4b9f
Binary files /dev/null and b/ptocr/dataloader/DetLoad/__pycache__/SASTProcess_ori.cpython-36.pyc differ
diff --git a/ptocr/dataloader/DetLoad/__pycache__/SASTProcess_ori1.cpython-36.pyc b/ptocr/dataloader/DetLoad/__pycache__/SASTProcess_ori1.cpython-36.pyc
new file mode 100644
index 0000000..e3eaeab
Binary files /dev/null and b/ptocr/dataloader/DetLoad/__pycache__/SASTProcess_ori1.cpython-36.pyc differ
diff --git a/ptocr/dataloader/DetLoad/__pycache__/__init__.cpython-36.pyc b/ptocr/dataloader/DetLoad/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..1c81771
Binary files /dev/null and b/ptocr/dataloader/DetLoad/__pycache__/__init__.cpython-36.pyc differ
diff --git a/ptocr/dataloader/DetLoad/__pycache__/transform_img.cpython-36.pyc b/ptocr/dataloader/DetLoad/__pycache__/transform_img.cpython-36.pyc
new file mode 100644
index 0000000..333ac1e
Binary files /dev/null and b/ptocr/dataloader/DetLoad/__pycache__/transform_img.cpython-36.pyc differ
diff --git a/ptocr/dataloader/DetLoad/transform_img.py b/ptocr/dataloader/DetLoad/transform_img.py
index ce02ead..adbb77a 100644
--- a/ptocr/dataloader/DetLoad/transform_img.py
+++ b/ptocr/dataloader/DetLoad/transform_img.py
@@ -76,6 +76,38 @@ def process(self, img, polys, dont_care):
new_dotcare.append(dont_care[i])
return img, new_polys, new_dotcare
+
+ def process_mul(self, img, polys, classes, dont_care):
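+        # Same as process(), but also carries each box's class id through the random crop.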
+ all_care_polys = []
+ for i in range(len(dont_care)):
+ if (dont_care[i] is False):
+ all_care_polys.append(polys[i])
+ crop_x, crop_y, crop_w, crop_h = self.crop_area(img, all_care_polys)
+ scale_w = self.size[0] / crop_w
+ scale_h = self.size[1] / crop_h
+ scale = min(scale_w, scale_h)
+ h = int(crop_h * scale)
+ w = int(crop_w * scale)
+ padimg = np.zeros(
+ (self.size[1], self.size[0], img.shape[2]), img.dtype)
+ padimg[:h, :w] = cv2.resize(
+ img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
+ img = padimg
+
+ new_polys = []
+ new_dotcare = []
+ new_classes = []
+
+ for i in range(len(polys)):
+ poly = polys[i]
+ poly = ((np.array(poly) -
+ (crop_x, crop_y)) * scale)
+ if not self.is_poly_outside_rect(poly, 0, 0, w, h):
+ new_polys.append(poly)
+ new_dotcare.append(dont_care[i])
+ new_classes.append(classes[i])
+
+ return img, new_polys,new_classes, new_dotcare
def is_poly_in_rect(self, poly, x, y, w, h):
poly = np.array(poly)
@@ -259,6 +291,10 @@ def random_flip(self, img, polys):
def random_crop_db(self, img, polys, dont_care):
img, new_polys, new_dotcare = self.random_crop_data.process(img, polys, dont_care)
return img, new_polys, new_dotcare
+
+ def random_crop_db_mul(self, img, polys,classes, dont_care):
+ img, new_polys, new_classes,new_dotcare = self.random_crop_data.process_mul(img, polys, classes,dont_care)
+ return img, new_polys,new_classes, new_dotcare
def random_crop_pse(self, imgs, ):
h, w = imgs[0].shape[0:2]
diff --git a/ptocr/dataloader/RecLoad/CRNNProcess.py b/ptocr/dataloader/RecLoad/CRNNProcess.py
index 68a370a..fb32fe2 100644
--- a/ptocr/dataloader/RecLoad/CRNNProcess.py
+++ b/ptocr/dataloader/RecLoad/CRNNProcess.py
@@ -1,8 +1,14 @@
+import lmdb
import torch
-import torchvision.transforms as transforms
+import six,sys,re,glob,os
+import numpy as np
+from PIL import Image,ImageFile
+ImageFile.LOAD_TRUNCATED_IMAGES = True
from torch.utils.data import Dataset
-from .DataAgument import transform_image_add,transform_img_shape,transform_image_one
-from PIL import Image
+import torchvision.transforms as transforms
+from ptocr.dataloader.RecLoad.DataAgument import transform_img_shape,DataAugment
+from ptocr.utils.util_function import create_module,PILImageToCV,CVImageToPIL
+
def get_img(path,is_gray = False):
img = Image.open(path).convert('RGB')
@@ -10,54 +16,36 @@ def get_img(path,is_gray = False):
img = img.convert('L')
return img
-class CRNNProcessTrain(Dataset):
- def __init__(self, config):
- super(CRNNProcessTrain,self).__init__()
- with open(config['trainload']['train_file'],'r',encoding='utf-8') as fid:
- self.label_list = []
- self.image_list = []
-
- for line in fid.readlines():
- line = line.strip('\n').replace('\ufeff','').split('\t')
- self.label_list.append(line[1])
- self.image_list.append(line[0])
-
+class CRNNProcessLmdbLoad(Dataset):
+ def __init__(self, config,lmdb_type):
self.config = config
+ self.lmdb_type = lmdb_type
- def __len__(self):
- return len(self.label_list)
+ if lmdb_type=='train':
+ lmdb_file = config['trainload']['train_file']
+ workers = config['trainload']['num_workers']
+ elif lmdb_type == 'val':
+ lmdb_file = config['valload']['val_file']
+ workers = config['valload']['num_workers']
+        else:
+            raise ValueError('lmdb_type error !!!')
- def __getitem__(self, index):
- img = get_img(self.image_list[index],is_gray=self.config['base']['is_gray'])
-# img = transform_image_add(img)
- img = transform_image_one(img)
- img = transform_img_shape(img,self.config['base']['img_shape'])
- img = Image.fromarray(img)
- img = transforms.ToTensor()(img)
- img.sub_(0.5).div_(0.5)
- label = self.label_list[index]
- return (img,label)
-
-
-class CRNNProcessTrainLmdb(Dataset):
- def __init__(self, config):
- self.env = lmdb.open(
- config['trainload']['train_file'],
- max_readers=1,
- readonly=True,
- lock=False,
- readahead=False,
- meminit=False)
+ self.env = lmdb.open(lmdb_file,max_readers=workers,readonly=True,lock=False,readahead=False,meminit=False)
if not self.env:
- print('cannot creat lmdb from %s' % (root))
+            print('cannot create lmdb from %s' % (lmdb_file))
sys.exit(0)
with self.env.begin(write=False) as txn:
nSamples = int(txn.get('num-samples'.encode('utf-8')))
self.nSamples = nSamples
- self.config = config
+ self.transform_label = create_module(config['label_transform']['label_function'])
+
+ self.bg_img = []
+ for path in glob.glob(os.path.join(config['trainload']['bg_path'], '*')):
+ self.bg_img.append(path)
def __len__(self):
return self.nSamples
@@ -79,22 +67,29 @@ def __getitem__(self, index):
return self[index + 1]
label_key = 'label-%09d' % index
- label = txn.get(label_key.encode('utf-8'))
-
+ label = txn.get(label_key.encode('utf-8')).decode()
+ label = self.transform_label(label, char_type=self.config['label_transform']['char_type'],t_type=self.config['label_transform']['t_type'])
if self.config['base']['is_gray']:
img = img.convert('L')
- img = np.array(img)
-
- img = transform_image_one(img)
- img = transform_img_shape(img,self.config['base']['img_shape'])
- img = Image.fromarray(img)
+ img = PILImageToCV(img,self.config['base']['is_gray'])
+ if self.lmdb_type == 'train':
+ try:
+ bg_index = np.random.randint(0, len(self.bg_img))
+ bg_img = PILImageToCV(get_img(self.bg_img[bg_index]),self.config['base']['is_gray'])
+ img = transform_img_shape(img, self.config['base']['img_shape'])
+ img = DataAugment(img, bg_img, self.config['base']['img_shape'])
+ img = transform_img_shape(img, self.config['base']['img_shape'])
+ except IOError:
+ print('Corrupted image for %d' % index)
+ return self[index + 1]
+ elif self.lmdb_type == 'val':
+ img = transform_img_shape(img, self.config['base']['img_shape'])
+ img = CVImageToPIL(img,self.config['base']['is_gray'])
img = transforms.ToTensor()(img)
img.sub_(0.5).div_(0.5)
-
- return (img, label)
-
+ return (img, label)
+
class alignCollate(object):
-
def __init__(self,):
pass
def __call__(self, batch):
@@ -102,28 +97,4 @@ def __call__(self, batch):
images = torch.stack(images,0)
return images,labels
-class CRNNProcessTest(Dataset):
- def __init__(self, config):
- super(CRNNProcessTest,self).__init__()
- with open(config['testload']['test_file'],'r',encoding='utf-8') as fid:
- self.label_list = []
- self.image_list = []
-
- for line in fid.readlines():
- line = line.strip('\n').replace('\ufeff','').split('\t')
- self.label_list.append(line[1])
- self.image_list.append(line[0])
-
- self.config = config
-
- def __len__(self):
- return len(self.label_list)
-
- def __getitem__(self, index):
- img = get_img(self.image_list[index],is_gray=self.config['base']['is_gray'])
- img = transform_img_shape(img,self.config['base']['img_shape'])
- img = Image.fromarray(img)
- img = transforms.ToTensor()(img)
- img.sub_(0.5).div_(0.5)
- label = self.label_list[index]
- return (img,label)
\ No newline at end of file
+
\ No newline at end of file
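The rewritten loader reads the de-facto CRNN lmdb layout: a `num-samples` count plus 1-indexed `image-%09d` / `label-%09d` pairs, with images stored as their encoded bytes. Below is a minimal writer sketch compatible with `CRNNProcessLmdbLoad`; the paths and labels are hypothetical, and the repo's `script` folder ships the real multi-process converter.

```python
import lmdb

def write_crnn_lmdb(lmdb_dir, samples):
    # samples: iterable of (image_path, label) pairs -- hypothetical input
    env = lmdb.open(lmdb_dir, map_size=1 << 40)
    with env.begin(write=True) as txn:
        n = 0
        for img_path, label in samples:
            n += 1  # keys are 1-indexed, matching the loader's index handling
            with open(img_path, 'rb') as f:
                txn.put(('image-%09d' % n).encode('utf-8'), f.read())
            txn.put(('label-%09d' % n).encode('utf-8'), label.encode('utf-8'))
        txn.put('num-samples'.encode('utf-8'), str(n).encode('utf-8'))
    env.close()

# usage with hypothetical files:
# write_crnn_lmdb('./train_lmdb', [('a.jpg', 'hello'), ('b.jpg', 'world')])
```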
diff --git a/ptocr/dataloader/RecLoad/CRNNProcess1.py b/ptocr/dataloader/RecLoad/CRNNProcess1.py
new file mode 100644
index 0000000..de9c6ec
--- /dev/null
+++ b/ptocr/dataloader/RecLoad/CRNNProcess1.py
@@ -0,0 +1,244 @@
+# -*- coding: utf-8 -*-
+"""
+@author:fxw
+@file: CRNNProcess1.py
+@time: 2021/03/23
+"""
+
+import lmdb
+import torch
+import six,re,glob,os
+import numpy as np
+from PIL import Image
+from PIL import ImageFile
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+from torch.utils.data import Dataset
+import torchvision.transforms as transforms
+from ptocr.dataloader.RecLoad.DataAgument import transform_img_shape,DataAugment
+from ptocr.utils.util_function import create_module
+
+
+def get_img(path, is_gray=False):
+    try:
+        img = Image.open(path).convert('RGB')
+    except Exception:
+        print('cannot read image: %s' % path)
+        # fall back to a blank uint8 canvas so training can continue
+        img = Image.fromarray(np.zeros((32, 320, 3), dtype=np.uint8)).convert('RGB')
+    if is_gray:
+        img = img.convert('L')
+    return img
+
+class alignCollate(object):
+ def __init__(self, ):
+ pass
+ def __call__(self, batch):
+ images, labels = zip(*batch)
+ images = torch.stack(images, 0)
+ return images, labels
+
+class CRNNProcessTrainLmdb(Dataset):
+ def __init__(self, config,lmdb_path):
+ self.env = lmdb.open(
+ lmdb_path,
+ max_readers=config['trainload']['num_workers'],
+ readonly=True,
+ lock=False,
+ readahead=False,
+ meminit=False)
+
+ with self.env.begin(write=False) as txn:
+ nSamples = int(txn.get('num-samples'.encode('utf-8')))
+ self.nSamples = nSamples
+
+ self.config = config
+ self.transform_label = create_module(config['transform']['function'])
+ self.bg_img = []
+ for path in glob.glob(os.path.join(config['trainload']['bg_path'],'*')):
+ self.bg_img.append(path)
+
+
+ def __len__(self):
+ return self.nSamples
+
+ def __getitem__(self, index):
+ assert index <= len(self), 'index range error'
+ index += 1
+ with self.env.begin(write=False) as txn:
+
+ img_key = 'image-%09d' % index
+ imgbuf = txn.get(img_key.encode('utf-8'))
+ buf = six.BytesIO()
+ buf.write(imgbuf)
+ buf.seek(0)
+
+ try:
+ img = Image.open(buf).convert('RGB')
+ except IOError:
+ print('IO image for %d' % index)
+ return self[index + 1]
+
+ label_key = 'label-%09d' % index
+ label = txn.get(label_key.encode('utf-8'))
+
+ if self.config['base']['is_gray']:
+ img = img.convert('L')
+
+ img = np.array(img)
+ if isinstance(label,bytes):
+ label = label.decode()
+
+ label = self.transform_label(label,char_type=self.config['transform']['char_type'],t_type=self.config['transform']['t_type'])
+
+
+ try:
+ bg_index = np.random.randint(0,len(self.bg_img))
+ bg_img = np.array(get_img(self.bg_img[bg_index]))
+ img = transform_img_shape(img,self.config['base']['img_shape'])
+            img = DataAugment(img, bg_img, self.config['base']['img_shape'])
+ img = transform_img_shape(img,self.config['base']['img_shape'])
+        except Exception:
+            print('Corrupted image for %d' % index)
+            # fall back to a blank crop matching the configured shape
+            img = np.zeros(self.config['base']['img_shape'], dtype=np.uint8)
+
+ img = Image.fromarray(img)
+ img = transforms.ToTensor()(img)
+ img.sub_(0.5).div_(0.5)
+
+ return (img, label)
+
+def GetDataLoad(config,data_type='train'):
+ if data_type == 'train':
+ lmdb_path_list = config['trainload']['train_file']
+ ratio = config['trainload']['batch_ratio']
+ elif data_type == 'val':
+ lmdb_path_list = config['valload']['val_file']
+
+ num = len(lmdb_path_list)
+ train_data_loaders = []
+ sum_num = 0
+ for i in range(num):
+ if data_type=='train':
+ if i==num-1:
+ batch_size = config['trainload']['batch_size'] - sum_num
+ else:
+ batch_size = int(config['trainload']['batch_size']*ratio[i])//2*2
+ sum_num+=batch_size
+ dataset = CRNNProcessTrainLmdb(config, lmdb_path_list[i])
+ num_workers = config['trainload']['num_workers']
+ shuffle = True
+ elif data_type == 'val':
+ batch_size = 1
+ num_workers = config['valload']['num_workers']
+ dataset = CRNNProcessTrainLmdb(config, lmdb_path_list[i])
+ shuffle = False
+
+ train_data_loader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ num_workers=num_workers,
+ collate_fn=alignCollate(),
+ drop_last=True,
+ pin_memory=True)
+ train_data_loaders.append(train_data_loader)
+ return train_data_loaders
+
+def GetValDataLoad(config):
+
+ val_dir = config['valload']['dir']
+ root= config['valload']['root']
+
+ num = len(val_dir)
+ data_loaders = []
+ for i in range(num):
+
+ batch_size = 1
+ num_workers = config['valload']['num_workers']
+ config['valload']['test_file'] = os.path.join(root,val_dir[i],'val_train.txt')
+ dataset = CRNNProcessTest(config)
+ shuffle = False
+
+ data_loader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ num_workers=num_workers,
+ collate_fn=alignCollate(),
+ drop_last=True,
+ pin_memory=True)
+ data_loaders.append(data_loader)
+ return data_loaders
+
+class CRNNProcessTest(Dataset):
+ def __init__(self, config):
+ super(CRNNProcessTest,self).__init__()
+ with open(config['valload']['test_file'],'r',encoding='utf-8') as fid:
+ self.label_list = []
+ self.image_list = []
+
+ for line in fid.readlines():
+ line = line.strip('\n').replace('\ufeff','').split('\t')
+ self.label_list.append(line[1])
+ self.image_list.append(line[0])
+
+ self.config = config
+
+ def __len__(self):
+ return len(self.label_list)
+
+ def __getitem__(self, index):
+ img = get_img(self.image_list[index],is_gray=self.config['base']['is_gray'])
+ img = transform_img_shape(img,self.config['base']['img_shape'])
+ img = Image.fromarray(img)
+ img = transforms.ToTensor()(img)
+ img.sub_(0.5).div_(0.5)
+ label = self.label_list[index]
+ return (img,label)
+
+
+# if __name__ == "__main__":
+
+# config = {}
+# config['base'] = {}
+# config['base']['img_shape'] = [32,100]
+# config['trainload'] = {}
+# config['valload'] = {}
+# config['trainload']['bg_path']='./bg_img/'
+# config['base']['is_gray'] = True
+# config['trainload']['train_file'] =['/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/','/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/MJSynth']
+
+# config['valload']['val_file'] = ['D:\BaiduNetdiskDownload\data\evaluation\CUTE80',
+# 'D:\BaiduNetdiskDownload\data\evaluation\IC03_860',
+# 'D:\BaiduNetdiskDownload\data\evaluation\IC03_867']
+# config['trainload']['num_workers'] = 0
+# config['valload']['num_workers'] = 0
+# config['trainload']['batch_size'] = 128
+# config['trainload']['batch_ratio'] = [0.33, 0.33, 0.33]
+# config['transform']={}
+
+# config['transform']['char_type'] = 'En'
+# config['transform']['t_type']='lower'
+# config['transform']['function'] = 'ptocr.dataloader.RecLoad.DataAgument,transform_label'
+# import time
+# train_data_loaders = GetDataLoad(config,data_type='train')
+# # import pdb
+# # pdb.set_trace()
+
+# # t = time.time()
+# # for idx,data1 in enumerate(train_data_loaders[0]):
+# # pass
+# # print(time.time()-t)
+
+# data1 = enumerate(train_data_loaders[0])
+# for i in range(100):
+# try:
+# t = time.time()
+# index,(data,label) = next(data1)
+# print(time.time()-t)
+# print(label)
+# except:
+# print('end')
+
+
\ No newline at end of file
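`GetDataLoad` composes one global batch from several lmdb sources: source `i` contributes `int(batch_size * ratio[i]) // 2 * 2` samples (rounded down to an even count), and the last source takes the remainder, so the pieces always sum to `batch_size`. A sketch of the expected config with placeholder paths (key names taken from this file, values illustrative):

```python
config = {
    'base': {'img_shape': [32, 100], 'is_gray': True},
    'trainload': {
        'train_file': ['./lmdb/MJSynth', './lmdb/SynthText'],  # placeholder paths
        'batch_ratio': [0.4, 0.6],
        'batch_size': 512,
        'num_workers': 8,
        'bg_path': './bg_img/',
    },
    'transform': {
        'function': 'ptocr.dataloader.RecLoad.DataAgument,transform_label',
        'char_type': 'En',
        't_type': 'lower',
    },
}

# per-source batch sizes: int(512 * 0.4) // 2 * 2 = 204 for source 0,
# then 512 - 204 = 308 for the last source -- always summing to 512
# loaders = GetDataLoad(config, data_type='train')  # needs the lmdb dirs to exist
```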
diff --git a/ptocr/dataloader/RecLoad/DataAgument.py b/ptocr/dataloader/RecLoad/DataAgument.py
index 1e797ce..c9af073 100644
--- a/ptocr/dataloader/RecLoad/DataAgument.py
+++ b/ptocr/dataloader/RecLoad/DataAgument.py
@@ -4,7 +4,7 @@
@file: DataAgument.py
@time: 2019/06/06
"""
-import cv2
+import cv2,re
import numpy as np
from skimage.util import random_noise
import os
@@ -113,28 +113,16 @@ def perform_operation(self, images):
dy = random.randint(-self.magnitude, self.magnitude)
x1, y1, x2, y2, x3, y3, x4, y4 = polygons[a]
- polygons[a] = [x1, y1,
- x2, y2,
- x3 + dx, y3 + dy,
- x4, y4]
+ polygons[a] = [x1, y1,x2, y2,x3 + dx, y3 + dy,x4, y4]
x1, y1, x2, y2, x3, y3, x4, y4 = polygons[b]
- polygons[b] = [x1, y1,
- x2 + dx, y2 + dy,
- x3, y3,
- x4, y4]
+ polygons[b] = [x1, y1, x2 + dx, y2 + dy,x3, y3,x4, y4]
x1, y1, x2, y2, x3, y3, x4, y4 = polygons[c]
- polygons[c] = [x1, y1,
- x2, y2,
- x3, y3,
- x4 + dx, y4 + dy]
+ polygons[c] = [x1, y1,x2, y2,x3, y3,x4 + dx, y4 + dy]
x1, y1, x2, y2, x3, y3, x4, y4 = polygons[d]
- polygons[d] = [x1 + dx, y1 + dy,
- x2, y2,
- x3, y3,
- x4, y4]
+ polygons[d] = [x1 + dx, y1 + dy,x2, y2,x3, y3,x4, y4]
generated_mesh = []
for i in range(len(dimensions)):
@@ -306,7 +294,7 @@ def transform_img_shape(img, img_shape):
img = np.array(img)
H, W = img_shape
h, w = img.shape[:2]
- new_w = int((float(H)/ h) * w)
+ new_w = int((float(H)/h) * w)
if (new_w > W):
img = cv2.resize(img, (W, H))
else:
@@ -328,11 +316,23 @@ def random_crop(image):
crop_img = crop_img[0:h - top_crop, :]
return crop_img
-def transform_image_one(image):
+def get_background_img(img, bg_img, img_shape=[32, 200]):
+    H, W = img_shape
+    # sample the crop offsets from the background image, not the text image
+    x = np.random.randint(0, bg_img.shape[0] - H + 1)
+    y = np.random.randint(0, bg_img.shape[1] - W + 1)
+    img_crop = bg_img[x:x + H, y:y + W]
+    # blend weight falls in roughly [0.1, 0.49]
+    ratio = np.random.randint(1, 4) / 10.0 + np.random.randint(0, 10) / 100.0
+    img = np.array(Image.fromarray(img).convert('RGB'))
+    img_crop = np.array(Image.fromarray(img_crop).convert('RGB'))
+    dst = cv2.addWeighted(img, 1 - ratio, img_crop, ratio, 2)
+    dst = np.array(Image.fromarray(dst).convert('L'))
+    return dst
+
+def DataAugment(image,bg_img,img_shape):
image = np.array(image)
if (np.random.choice([True, False], 1)[0]):
dataAu = DataAugmentatonMore(image)
- index = np.random.randint(0,10)
+ index = np.random.randint(0,11)
if (index == 0):
degree = np.random.randint(2, 6)
angle = np.random.randint(0, 360)
@@ -367,62 +367,24 @@ def transform_image_one(image):
image = RandomAddLine(image)
elif(index==9):
image = random_crop(image)
- del dataAu
+ elif(index==10):
+        image = get_background_img(image, bg_img, img_shape)
+ del dataAu,bg_img
return image
-def transform_image_add(image):
- image = np.array(image)
- is_transform = np.random.choice([True, False], 1)[0]
- # image = RandomAddLine(image)
- if (is_transform):
- dataAu = DataAugmentatonMore(image)
- is_transform = np.random.choice([True, False], 1)[0]
- if (is_transform):
- image = dataAu.image
- image = random_dilute(image)
- dataAu.image = image
- is_transform = np.random.choice([True, False], 1)[0]
- if (is_transform):
- image = dataAu.image
- image = Image.fromarray(image)
- image = GetRandomDistortImage([image])[0]
- dataAu.image = image
- is_transform = np.random.choice([True, False], 1)[0]
- if (is_transform):
- degree = np.random.randint(2, 8)
- angle = np.random.randint(0, 360)
- image = dataAu.motion_blur(degree, angle)
- dataAu.image = image
- is_transform = np.random.choice([True, False], 1)[0]
- if (is_transform):
- id = np.random.randint(0, 3)
- k_size = [3, 5, 7]
- image = dataAu.gaussian_blur(k_size[id])
- dataAu.image = image
- is_transform = np.random.choice([True, False], 1)[0]
- if (is_transform):
- alpha = np.random.uniform(0.6, 1.3)
- image = dataAu.Contrast_and_Brightness(alpha)
- dataAu.image = image
- is_transform = np.random.choice([True, False], 1)[0]
- if (is_transform):
- types = ['top', 'botttom']
- id = np.random.randint(0, 1)
- ratio = np.random.randint(0, 5)
- image = dataAu.Perspective(ratio, types[id])
- dataAu.image = image
- is_transform = np.random.choice([True, False], 1)[0]
- if (is_transform and image.shape[0] > 20):
- ratio = np.random.uniform(0.45, 0.7)
- image = dataAu.resize_blur(ratio)
- dataAu.image = image
- del dataAu
- return image
-
-
-
+def transform_label(label, char_type='En', t_type='lower'):
+    if char_type == 'En':
+        label = ''.join(re.findall('[0-9a-zA-Z]+', label))
+        if t_type == 'lower':
+            label = label.lower()
+        elif t_type == 'upper':
+            label = label.upper()
+    return label
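The new case 10 in `DataAugment` blends a random background crop into the text line via `cv2.addWeighted`, i.e. `dst = (1 - ratio) * img + ratio * crop + 2` with `ratio` drawn from roughly [0.1, 0.49]. A self-contained sketch of the same blend on synthetic arrays:

```python
import cv2
import numpy as np

H, W = 32, 200
text_img = np.full((H, W, 3), 255, dtype=np.uint8)            # stand-in text line
bg = np.random.randint(0, 256, (64, 400, 3), dtype=np.uint8)  # stand-in photo

# sample a crop offset inside the background, as get_background_img does
x = np.random.randint(0, bg.shape[0] - H + 1)
y = np.random.randint(0, bg.shape[1] - W + 1)
crop = bg[x:x + H, y:y + W]

ratio = np.random.randint(1, 4) / 10.0 + np.random.randint(0, 10) / 100.0
blended = cv2.addWeighted(text_img, 1 - ratio, crop, ratio, 2)
print(blended.shape, round(ratio, 2))  # (32, 200, 3) and e.g. 0.37
```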
diff --git a/ptocr/model/architectures/det_model.py b/ptocr/model/architectures/det_model.py
index 0f91513..78f7337 100644
--- a/ptocr/model/architectures/det_model.py
+++ b/ptocr/model/architectures/det_model.py
@@ -20,9 +20,15 @@ def __init__(self, config):
self.head = create_module(config['head']['function']) \
(config['base']['in_channels'],
config['base']['inner_channels'])
-
+ self.mulclass = False
if (config['base']['algorithm']) == 'DB':
- self.seg_out = create_module(config['segout']['function'])(config['base']['inner_channels'],
+ if 'n_class' in config['base'].keys():
+ self.mulclass = True
+ self.seg_out = create_module(config['segout']['function'])(config['base']['n_class'],config['base']['inner_channels'],
+ config['base']['k'],
+ config['base']['adaptive'])
+ else:
+ self.seg_out = create_module(config['segout']['function'])(config['base']['inner_channels'],
config['base']['k'],
config['base']['adaptive'])
elif (config['base']['algorithm']) == 'PAN':
@@ -40,14 +46,21 @@ def __init__(self, config):
def forward(self, data):
if self.training:
if self.algorithm == "DB":
- img, gt, gt_mask, thresh_map, thresh_mask = data
+ if self.mulclass:
+ img, gt,gt_class, gt_mask, thresh_map, thresh_mask = data
+ else:
+ img, gt, gt_mask, thresh_map, thresh_mask = data
if torch.cuda.is_available():
+ if self.mulclass:
+ gt_class = gt_class.cuda()
img, gt, gt_mask, thresh_map, thresh_mask = \
img.cuda(), gt.cuda(), gt_mask.cuda(), thresh_map.cuda(), thresh_mask.cuda()
gt_batch = dict(gt=gt)
gt_batch['mask'] = gt_mask
gt_batch['thresh_map'] = thresh_map
gt_batch['thresh_mask'] = thresh_mask
+ if self.mulclass:
+ gt_batch['gt_class'] = gt_class
elif self.algorithm == "PSE":
img, gt_text, gt_kernels, train_mask = data
@@ -98,8 +111,12 @@ def __init__(self, config):
super(DetLoss, self).__init__()
self.algorithm = config['base']['algorithm']
if (config['base']['algorithm']) == 'DB':
- self.loss = create_module(config['loss']['function'])(config['loss']['l1_scale'],
- config['loss']['bce_scale'])
+ if 'n_class' in config['base'].keys():
+ self.loss = create_module(config['loss']['function'])(config['base']['n_class'],config['loss']['l1_scale'],
+ config['loss']['bce_scale'],config['loss']['class_scale'])
+ else:
+ self.loss = create_module(config['loss']['function'])(config['loss']['l1_scale'],config['loss']['bce_scale'])
+
elif (config['base']['algorithm']) == 'PAN':
self.loss = create_module(config['loss']['function'])(config['loss']['kernel_rate'],
config['loss']['agg_dis_rate'])
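The multi-class DB branch is driven entirely by config: once `n_class` appears under `base`, `DetModel` expects the loader to yield an extra `gt_class` map and `DetLoss` adds a classification term weighted by `class_scale`. A sketch of the relevant keys as a Python dict mirroring the YAML layout (the loss function path is an assumption, not necessarily the repo's exact module name):

```python
config = {
    'base': {
        'algorithm': 'DB',
        'n_class': 3,        # presence of this key switches on multi-class mode
        'inner_channels': 96,
        'k': 50,
        'adaptive': True,
    },
    'loss': {
        'function': 'ptocr.model.loss.db_loss,DBLossMulClass',  # assumed name
        'l1_scale': 10,
        'bce_scale': 5,
        'class_scale': 1,    # weight of the per-pixel class term
    },
}

# a training batch then carries six tensors instead of five:
# img, gt, gt_class, gt_mask, thresh_map, thresh_mask
```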
diff --git a/ptocr/model/architectures/paddle_tps.py b/ptocr/model/architectures/paddle_tps.py
new file mode 100644
index 0000000..c1f5a44
--- /dev/null
+++ b/ptocr/model/architectures/paddle_tps.py
@@ -0,0 +1,252 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+
+class ConvBNLayer(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ is_relu=False):
+ super(ConvBNLayer, self).__init__()
+ self.conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=(kernel_size - 1) // 2,
+ groups=groups,
+ )
+ self.bn = nn.BatchNorm2d(
+ out_channels)
+ if is_relu:
+ self.relu = nn.ReLU()
+ self.is_relu = is_relu
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ if self.is_relu:
+ x = self.relu(x)
+ return x
+
+
+class LocalizationNetwork(nn.Module):
+ def __init__(self, in_channels, num_fiducial, model_name):
+ super(LocalizationNetwork, self).__init__()
+ self.F = num_fiducial
+ F = num_fiducial
+ if model_name == "large":
+ num_filters_list = [64, 128, 256, 512]
+ fc_dim = 256
+ else:
+ num_filters_list = [16, 32, 64, 128]
+ fc_dim = 64
+
+ block_list = []
+ for fno in range(0, len(num_filters_list)):
+ num_filters = num_filters_list[fno]
+ conv = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=num_filters,
+ kernel_size=3,
+ is_relu=True)
+ block_list.append(conv)
+ if fno == len(num_filters_list) - 1:
+ pool = nn.AdaptiveAvgPool2d(1)
+ else:
+ pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
+ in_channels = num_filters
+ block_list.append(pool)
+ self.fc1 = nn.Linear(
+ in_channels,
+ fc_dim)
+
+ # Init fc2 in LocalizationNetwork
+ self.fc2 = nn.Linear(
+ fc_dim,
+ F * 2)
+ self.out_channels = F * 2
+ self.block_list = nn.Sequential(*block_list)
+
+ def forward(self, x):
+ """
+ Estimating parameters of geometric transformation
+ Args:
+ image: input
+ Return:
+ batch_C_prime: the matrix of the geometric transformation
+ """
+        for block in self.block_list:
+            x = block(x)
+ x = x.squeeze(2).squeeze(2)
+ x = self.fc1(x)
+
+ x = F.relu(x)
+ x = self.fc2(x)
+ x = x.reshape(shape=[-1, self.F, 2])
+ return x
+
+ def get_initial_fiducials(self):
+ """ see RARE paper Fig. 6 (a) """
+ F = self.F
+ ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
+ ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2))
+ ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2))
+ ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
+ ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
+ initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
+ return initial_bias
+
+
+class GridGenerator(nn.Module):
+ def __init__(self, in_channels, num_fiducial):
+ super(GridGenerator, self).__init__()
+ self.eps = 1e-6
+ self.F = num_fiducial
+ self.fc = nn.Linear(
+ in_channels,
+ 6)
+
+ def forward(self, batch_C_prime, I_r_size):
+ """
+ Generate the grid for the grid_sampler.
+ Args:
+ batch_C_prime: the matrix of the geometric transformation
+ I_r_size: the shape of the input image
+ Return:
+ batch_P_prime: the grid for the grid_sampler
+ """
+ C = self.build_C_paddle()
+ P = self.build_P_paddle(I_r_size)
+
+ inv_delta_C_tensor = self.build_inv_delta_C_paddle(C)
+ P_hat_tensor = self.build_P_hat_paddle(
+ C, torch.Tensor(P))
+
+        # Paddle's ``stop_gradient = True`` is a silent no-op on torch tensors;
+        # detach() is the PyTorch equivalent
+        inv_delta_C_tensor = inv_delta_C_tensor.detach()
+        P_hat_tensor = P_hat_tensor.detach()
+
+        batch_C_ex_part_tensor = self.get_expand_tensor(batch_C_prime)
+        batch_C_ex_part_tensor = batch_C_ex_part_tensor.detach()
+
+ batch_C_prime_with_zeros = torch.cat([batch_C_prime, batch_C_ex_part_tensor],1)
+ batch_T = torch.matmul(inv_delta_C_tensor, batch_C_prime_with_zeros)
+ batch_P_prime = torch.matmul(P_hat_tensor, batch_T)
+ return batch_P_prime
+
+ def build_C_paddle(self):
+ """ Return coordinates of fiducial points in I_r; C """
+ F = self.F
+ ctrl_pts_x = torch.linspace(-1.0, 1.0, int(F / 2))
+ ctrl_pts_y_top = -1 * torch.ones([int(F / 2)])
+ ctrl_pts_y_bottom = torch.ones([int(F / 2)])
+ ctrl_pts_top = torch.stack([ctrl_pts_x, ctrl_pts_y_top], 1)
+ ctrl_pts_bottom = torch.stack([ctrl_pts_x, ctrl_pts_y_bottom],1)
+ C = torch.cat([ctrl_pts_top, ctrl_pts_bottom], 0)
+ return C # F x 2
+
+ def build_P_paddle(self, I_r_size):
+ I_r_height, I_r_width = I_r_size
+        I_r_grid_x = (torch.arange(-I_r_width, I_r_width, 2) + 1.0) / float(I_r_width)
+
+        I_r_grid_y = (torch.arange(-I_r_height, I_r_height, 2) + 1.0) / float(I_r_height)
+
+ # P: self.I_r_width x self.I_r_height x 2
+ P = torch.stack(torch.meshgrid(I_r_grid_x, I_r_grid_y), 2)
+ P = P.permute(1, 0, 2)
+ # n (= self.I_r_width x self.I_r_height) x 2
+ return P.reshape([-1, 2])
+
+ def build_inv_delta_C_paddle(self, C):
+ """ Return inv_delta_C which is needed to calculate T """
+ F = self.F
+ hat_C = torch.zeros((F, F)) # F x F
+ for i in range(0, F):
+ for j in range(i, F):
+ if i == j:
+ hat_C[i, j] = 1
+ else:
+ r = torch.norm(C[i] - C[j])
+ hat_C[i, j] = r
+ hat_C[j, i] = r
+ hat_C = (hat_C**2) * torch.log(hat_C)
+        delta_C = torch.cat([
+            torch.cat([torch.ones((F, 1)), C, hat_C], 1),                   # F x F+3
+            torch.cat([torch.zeros((2, 3)), torch.transpose(C, 1, 0)], 1),  # 2 x F+3
+            torch.cat([torch.zeros((1, 3)), torch.ones((1, F))], 1),        # 1 x F+3
+        ], 0)  # (F+3) x (F+3)
+ inv_delta_C = torch.inverse(delta_C)
+ return inv_delta_C # F+3 x F+3
+
+ def build_P_hat_paddle(self, C, P):
+ F = self.F
+ eps = self.eps
+ n = P.shape[0] # n (= self.I_r_width x self.I_r_height)
+ # P_tile: n x 2 -> n x 1 x 2 -> n x F x 2
+ P_tile = P.unsqueeze(1).repeat(1, F, 1)
+ C_tile = torch.unsqueeze(C, 0) # 1 x F x 2
+ P_diff = P_tile - C_tile # n x F x 2
+ # rbf_norm: n x F
+ rbf_norm = torch.norm(P_diff, p=2, dim=2, keepdim=False)
+
+ # rbf: n x F
+ rbf = torch.mul(
+ torch.square(rbf_norm), torch.log(rbf_norm + eps))
+        P_hat = torch.cat([torch.ones((n, 1)), P, rbf], 1)
+ return P_hat # n x F+3
+
+ def get_expand_tensor(self, batch_C_prime):
+ B, H, C = batch_C_prime.shape
+ batch_C_prime = batch_C_prime.reshape([B, H * C])
+ batch_C_ex_part_tensor = self.fc(batch_C_prime)
+ batch_C_ex_part_tensor = batch_C_ex_part_tensor.reshape([-1, 3, 2])
+ return batch_C_ex_part_tensor
+
+
+class TPS(nn.Module):
+ def __init__(self, in_channels, num_fiducial, model_name):
+ super(TPS, self).__init__()
+ self.loc_net = LocalizationNetwork(in_channels, num_fiducial,
+ model_name)
+ self.grid_generator = GridGenerator(self.loc_net.out_channels,
+ num_fiducial)
+ self.out_channels = in_channels
+
+ def forward(self, image):
+ batch_C_prime = self.loc_net(image)
+ batch_P_prime = self.grid_generator(batch_C_prime, image.shape[2:])
+ batch_P_prime = batch_P_prime.reshape(
+ [-1, image.shape[2], image.shape[3], 2])
+ batch_I_r = F.grid_sample(image, grid=batch_P_prime,align_corners=True)
+ return batch_I_r
+
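`paddle_tps.py` ports PaddleOCR's TPS rectifier to PyTorch: the localization net regresses `2F` fiducial coordinates, `GridGenerator` solves the thin-plate-spline system (`T = inv_delta_C · C'_padded`, then `P' = P_hat · T`), and `grid_sample` warps the input accordingly. A quick shape check, assuming the module imports from the path this diff creates:

```python
import torch
from ptocr.model.architectures.paddle_tps import TPS  # path as added in this diff

# 20 fiducial points, grayscale input, the "small" localization net
tps = TPS(in_channels=1, num_fiducial=20, model_name='small')
x = torch.randn(2, 1, 32, 100)  # (batch, channel, H, W) dummy text crops
with torch.no_grad():
    rectified = tps(x)
print(rectified.shape)  # torch.Size([2, 1, 32, 100]) -- shape is preserved
```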
diff --git a/ptocr/model/architectures/rec_model.py b/ptocr/model/architectures/rec_model.py
index 9df17b2..9f5c2d4 100644
--- a/ptocr/model/architectures/rec_model.py
+++ b/ptocr/model/architectures/rec_model.py
@@ -6,27 +6,107 @@
"""
import torch
import torch.nn as nn
+import torch.nn.functional as F
from .. import create_module
+import cv2
class RecModel(nn.Module):
def __init__(self, config):
super(RecModel, self).__init__()
- self.algorithm = config['base']['algorithm']
+ self.algorithm = config['base']['algorithm']
+
self.backbone = create_module(config['backbone']['function'])(config['base']['pretrained'],config['base']['is_gray'])
- self.head = create_module(config['head']['function'])(
- use_conv=config['base']['use_conv'],
- use_attention=config['base']['use_attention'],
- use_lstm=config['base']['use_lstm'],
- lstm_num=config['base']['lstm_num'],
- inchannel=config['base']['inchannel'],
- hiddenchannel=config['base']['hiddenchannel'],
- classes=config['base']['classes'])
+ if self.algorithm=='CRNN':
+ self.head = create_module(config['head']['function'])(
+ use_attention=config['base']['use_attention'],
+ use_lstm=config['base']['use_lstm'],
+ time_step=config['base']['img_shape'][1]//4,
+ lstm_num=config['base']['lstm_num'],
+ inchannel=config['base']['inchannel'],
+ hiddenchannel=config['base']['hiddenchannel'],
+ classes=config['base']['classes'])
+ elif self.algorithm=='FC':
+ self.head = create_module(config['head']['function'])(in_channels=config['base']['in_channels'],
+ out_channels=config['base']['out_channels'],
+ max_length = config['base']['max_length'],
+ num_class = config['base']['num_class'])
+
+    def forward(self, x):
+        x = self.backbone(x)
+        logits, feats = self.head(x)
+        return logits, feats
+
+# class RecModel(nn.Module):
+# def __init__(self, config):
+# super(RecModel, self).__init__()
+# self.algorithm = config['base']['algorithm']
+# if config['base']['is_gray']:
+# in_planes = 1
+# else:
+# in_planes = 3
+
+# self.backbone = create_module(config['backbone']['function'])(config['base']['pretrained'],config['base']['is_gray'])
+# self.head = create_module(config['head']['function'])(
+# use_conv=config['base']['use_conv'],
+# use_attention=config['base']['use_attention'],
+# use_lstm=config['base']['use_lstm'],
+# lstm_num=config['base']['lstm_num'],
+# inchannel=config['base']['inchannel'],
+# hiddenchannel=config['base']['hiddenchannel'],
+# classes=config['base']['classes'])
+
+# self.stn_head = create_module(config['stn']['function'])(in_planes,config)
+
+# def forward(self, x):
+# cv2.imwrite('stn_ori1.jpg',(x[0,0].cpu().detach().numpy()*0.5+0.5)*255)
+
+# x1= self.stn_head(x)
+# cv2.imwrite('stn1.jpg',(x1[0,0].cpu().detach().numpy()*0.5+0.5)*255)
+
+# x1 = self.backbone(x1)
+# x1 = self.head(x1)
+# return x1
+
+# class RecModel(nn.Module):
+# def __init__(self, config):
+# super(RecModel, self).__init__()
+# self.algorithm = config['base']['algorithm']
+# self.tps_inputsize = config['stn']['tps_inputsize']
+# if config['base']['is_gray']:
+# in_planes = 1
+# else:
+# in_planes = 3
+
+# self.backbone = create_module(config['backbone']['function'])(config['base']['pretrained'],config['base']['is_gray'])
+# self.head = create_module(config['head']['function'])(
+# use_conv=config['base']['use_conv'],
+# use_attention=config['base']['use_attention'],
+# use_lstm=config['base']['use_lstm'],
+# lstm_num=config['base']['lstm_num'],
+# inchannel=config['base']['inchannel'],
+# hiddenchannel=config['base']['hiddenchannel'],
+# classes=config['base']['classes'])
+
+# self.tps = create_module(config['stn']['t_function'])( output_image_size=tuple(config['stn']['tps_outputsize']),
+# num_control_points=config['stn']['num_control_points'],
+# margins=tuple(config['stn']['tps_margins']))
+# self.stn_head = create_module(config['stn']['function'])(in_planes=in_planes,
+# num_ctrlpoints=config['stn']['num_control_points'],
+# activation=config['stn']['stn_activation'])
+
+# def forward(self, x):
+# cv2.imwrite('stn_ori.jpg',(x[0,0].cpu().detach().numpy()*0.5+0.5)*255)
+# stn_input = F.interpolate(x, self.tps_inputsize, mode='bilinear', align_corners=True)
+# stn_img_feat, ctrl_points = self.stn_head(stn_input)
- def forward(self, img):
- x = self.backbone(img)
- x = self.head(x)
- return x
+# x1, _ = self.tps(x, ctrl_points)
+
+# cv2.imwrite('stn.jpg',(x1[0,0].cpu().detach().numpy()*0.5+0.5)*255)
+
+# x1 = self.backbone(x1)
+# x1 = self.head(x1)
+# return x1
class RecLoss(nn.Module):
@@ -35,6 +115,8 @@ def __init__(self, config):
self.algorithm = config['base']['algorithm']
if (config['base']['algorithm']) == 'CRNN':
self.loss = create_module(config['loss']['function'])(config)
+ elif self.algorithm=='FC':
+ self.loss = create_module(config['loss']['function'])(ignore_index = config['base']['ignore_index'])
else:
assert True == False, ('not support this algorithm !!!')
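`RecModel.forward` now returns a `(logits, features)` pair instead of bare logits, which lets the CTC loss and the CenterLoss from the changelog share one forward pass. A sketch of how a training step might unpack it (the loss callables and the 0.05 weight are assumptions for illustration):

```python
def train_step(model, ctc_loss, center_loss, imgs, labels, center_weight=0.05):
    # model: RecModel; ctc_loss / center_loss: assumed callables
    logits, feats = model(imgs)            # forward now yields both outputs
    loss = ctc_loss(logits, labels)        # sequence loss on the logits
    if center_loss is not None:
        loss = loss + center_weight * center_loss(feats, labels)  # feature clustering
    return loss
```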
diff --git a/ptocr/model/backbone/det_densenet.py b/ptocr/model/backbone/det_densenet.py
deleted file mode 100644
index 4bf2353..0000000
--- a/ptocr/model/backbone/det_densenet.py
+++ /dev/null
@@ -1,242 +0,0 @@
-#-*- coding:utf-8 _*-
-"""
-@author:fxw
-@file: det_densenet.py
-@time: 2020/08/07
-"""
-import re
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.model_zoo as model_zoo
-from collections import OrderedDict
-
-__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
-
-
-model_urls = {
- 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
- 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
- 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
- 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
-}
-
-
-class _DenseLayer(nn.Sequential):
- def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
- super(_DenseLayer, self).__init__()
- self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
- self.add_module('relu1', nn.ReLU(inplace=True)),
- self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
- growth_rate, kernel_size=1, stride=1, bias=False)),
- self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
- self.add_module('relu2', nn.ReLU(inplace=True)),
- self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
- kernel_size=3, stride=1, padding=1, bias=False)),
- self.drop_rate = drop_rate
-
- def forward(self, x):
- new_features = super(_DenseLayer, self).forward(x)
- if self.drop_rate > 0:
- new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
- return torch.cat([x, new_features], 1)
-
-
-class _DenseBlock(nn.Sequential):
- def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
- super(_DenseBlock, self).__init__()
- for i in range(num_layers):
- layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
- self.add_module('denselayer%d' % (i + 1), layer)
-
-
-class _Transition(nn.Sequential):
- def __init__(self, num_input_features, num_output_features):
- super(_Transition, self).__init__()
- self.add_module('norm', nn.BatchNorm2d(num_input_features))
- self.add_module('relu', nn.ReLU(inplace=True))
- self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
- kernel_size=1, stride=1, bias=False))
- self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
-
-
-class DenseNet(nn.Module):
- r"""Densenet-BC model class, based on
- `"Densely Connected Convolutional Networks" `_
-
- Args:
- growth_rate (int) - how many filters to add each layer (`k` in paper)
- block_config (list of 4 ints) - how many layers in each pooling block
- num_init_features (int) - the number of filters to learn in the first convolution layer
- bn_size (int) - multiplicative factor for number of bottle neck layers
- (i.e. bn_size * k features in the bottleneck layer)
- drop_rate (float) - dropout rate after each dense layer
- num_classes (int) - number of classification classes
- """
-
- def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
- num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):
-
- super(DenseNet, self).__init__()
-
- # First convolution
- self.features = nn.Sequential(OrderedDict([
- ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
- ('norm0', nn.BatchNorm2d(num_init_features)),
- ('relu0', nn.ReLU(inplace=True)),
- ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
- ]))
-
- # Each denseblock
- num_features = num_init_features
- for i, num_layers in enumerate(block_config):
- block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
- bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
- self.features.add_module('denseblock%d' % (i + 1), block)
- num_features = num_features + num_layers * growth_rate
- if i != len(block_config) - 1:
- trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
- self.features.add_module('transition%d' % (i + 1), trans)
- num_features = num_features // 2
-
- # Final batch norm
- self.features.add_module('norm5', nn.BatchNorm2d(num_features))
-
- # Linear layer
- self.classifier = nn.Linear(num_features, num_classes)
-
- # Official init from torch repo.
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight)
- elif isinstance(m, nn.BatchNorm2d):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.Linear):
- nn.init.constant_(m.bias, 0)
-
- # def forward(self, x):
- # features = self.features(x)
- # out = F.relu(features, inplace=True)
- # out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1)
- # out = self.classifier(out)
- # return out
-
- def forward(self, x):
- out_list = []
- for layer in self.features.children():
- x = layer(x)
- if isinstance(layer, _DenseBlock):
- out_list.append(x)
- p1, p2, p3, p4 = out_list
- return p1, p2, p3, p4
-
-
-
-def densenet121(pretrained=False, **kwargs):
- r"""Densenet-121 model from
- `"Densely Connected Convolutional Networks" `_
-
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
- **kwargs)
- if pretrained:
- # '.'s are no longer allowed in module names, but pervious _DenseLayer
- # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
- # They are also in the checkpoints in model_urls. This pattern is used
- # to find such keys.
- pattern = re.compile(
- r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
- state_dict = model_zoo.load_url(model_urls['densenet121'])
- for key in list(state_dict.keys()):
- res = pattern.match(key)
- if res:
- new_key = res.group(1) + res.group(2)
- state_dict[new_key] = state_dict[key]
- del state_dict[key]
- model.load_state_dict(state_dict)
- return model
-
-
-def densenet169(pretrained=False, **kwargs):
- r"""Densenet-169 model from
- `"Densely Connected Convolutional Networks" `_
-
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
- **kwargs)
- if pretrained:
- # '.'s are no longer allowed in module names, but pervious _DenseLayer
- # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
- # They are also in the checkpoints in model_urls. This pattern is used
- # to find such keys.
- pattern = re.compile(
- r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
- state_dict = model_zoo.load_url(model_urls['densenet169'])
- for key in list(state_dict.keys()):
- res = pattern.match(key)
- if res:
- new_key = res.group(1) + res.group(2)
- state_dict[new_key] = state_dict[key]
- del state_dict[key]
- model.load_state_dict(state_dict)
- return model
-
-
-def densenet201(pretrained=False, **kwargs):
- r"""Densenet-201 model from
- `"Densely Connected Convolutional Networks" `_
-
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
- **kwargs)
- if pretrained:
- # '.'s are no longer allowed in module names, but pervious _DenseLayer
- # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
- # They are also in the checkpoints in model_urls. This pattern is used
- # to find such keys.
- pattern = re.compile(
- r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
- state_dict = model_zoo.load_url(model_urls['densenet201'])
- for key in list(state_dict.keys()):
- res = pattern.match(key)
- if res:
- new_key = res.group(1) + res.group(2)
- state_dict[new_key] = state_dict[key]
- del state_dict[key]
- model.load_state_dict(state_dict)
- return model
-
-
-def densenet161(pretrained=False, **kwargs):
- r"""Densenet-161 model from
- `"Densely Connected Convolutional Networks" `_
-
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
- **kwargs)
- if pretrained:
- # '.'s are no longer allowed in module names, but pervious _DenseLayer
- # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
- # They are also in the checkpoints in model_urls. This pattern is used
- # to find such keys.
- pattern = re.compile(
- r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
- state_dict = model_zoo.load_url(model_urls['densenet161'])
- for key in list(state_dict.keys()):
- res = pattern.match(key)
- if res:
- new_key = res.group(1) + res.group(2)
- state_dict[new_key] = state_dict[key]
- del state_dict[key]
- model.load_state_dict(state_dict)
- return model
-
diff --git a/ptocr/model/backbone/det_scnet.py b/ptocr/model/backbone/det_scnet.py
deleted file mode 100644
index 714b84b..0000000
--- a/ptocr/model/backbone/det_scnet.py
+++ /dev/null
@@ -1,294 +0,0 @@
-#-*- coding:utf-8 _*-
-"""
-@author:fxw
-@file: det_scnet.py
-@time: 2020/08/07
-"""
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.model_zoo as model_zoo
-
-__all__ = ['SCNet', 'scnet50', 'scnet50_v1d']
-
-model_urls = {
- 'scnet50': 'https://backseason.oss-cn-beijing.aliyuncs.com/scnet/scnet50-dc6a7e87.pth',
- 'scnet50_v1d': 'https://backseason.oss-cn-beijing.aliyuncs.com/scnet/scnet50_v1d-4109d1e1.pth',
-}
-
-class SCConv(nn.Module):
- def __init__(self, inplanes, planes, stride, padding, dilation, groups, pooling_r, norm_layer):
- super(SCConv, self).__init__()
- self.k2 = nn.Sequential(
- nn.AvgPool2d(kernel_size=pooling_r, stride=pooling_r,padding=1),
- nn.Conv2d(inplanes, planes, kernel_size=3, stride=1,
- padding=padding, dilation=dilation,
- groups=groups, bias=False),
- norm_layer(planes),
- )
- self.k3 = nn.Sequential(
- nn.Conv2d(inplanes, planes, kernel_size=3, stride=1,
- padding=padding, dilation=dilation,
- groups=groups, bias=False),
- norm_layer(planes),
- )
- self.k4 = nn.Sequential(
- nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
- padding=padding, dilation=dilation,
- groups=groups, bias=False),
- norm_layer(planes),
- )
-
- def forward(self, x):
- identity = x
-
- out = torch.sigmoid(torch.add(identity, F.interpolate(self.k2(x), identity.size()[2:]))) # sigmoid(identity + k2)
- out = torch.mul(self.k3(x), out) # k3 * sigmoid(identity + k2)
- out = self.k4(out) # k4
-
- return out
-
-class SCBottleneck(nn.Module):
- """SCNet SCBottleneck
- """
- expansion = 4
- pooling_r = 4 # down-sampling rate of the avg pooling layer in the K3 path of SC-Conv.
-
- def __init__(self, inplanes, planes, stride=1, downsample=None,
- cardinality=1, bottleneck_width=32,
- avd=False, dilation=1, is_first=False,
- norm_layer=None):
- super(SCBottleneck, self).__init__()
- group_width = int(planes * (bottleneck_width / 64.)) * cardinality
- self.conv1_a = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
- self.bn1_a = norm_layer(group_width)
- self.conv1_b = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
- self.bn1_b = norm_layer(group_width)
- self.avd = avd and (stride > 1 or is_first)
-
- if self.avd:
- self.avd_layer = nn.AvgPool2d(3, stride, padding=1)
- stride = 1
-
- self.k1 = nn.Sequential(
- nn.Conv2d(
- group_width, group_width, kernel_size=3, stride=stride,
- padding=dilation, dilation=dilation,
- groups=cardinality, bias=False),
- norm_layer(group_width),
- )
-
- self.scconv = SCConv(
- group_width, group_width, stride=stride,
- padding=dilation, dilation=dilation,
- groups=cardinality, pooling_r=self.pooling_r, norm_layer=norm_layer)
-
- self.conv3 = nn.Conv2d(
- group_width * 2, planes * 4, kernel_size=1, bias=False)
- self.bn3 = norm_layer(planes*4)
-
- self.relu = nn.ReLU(inplace=True)
- self.downsample = downsample
- self.dilation = dilation
- self.stride = stride
-
- def forward(self, x):
- residual = x
-
- out_a= self.conv1_a(x)
- out_a = self.bn1_a(out_a)
- out_b = self.conv1_b(x)
- out_b = self.bn1_b(out_b)
- out_a = self.relu(out_a)
- out_b = self.relu(out_b)
-
- out_a = self.k1(out_a)
- out_b = self.scconv(out_b)
- out_a = self.relu(out_a)
- out_b = self.relu(out_b)
-
- if self.avd:
- out_a = self.avd_layer(out_a)
- out_b = self.avd_layer(out_b)
-
- out = self.conv3(torch.cat([out_a, out_b], dim=1))
- out = self.bn3(out)
-
- if self.downsample is not None:
- residual = self.downsample(x)
-
- out += residual
- out = self.relu(out)
-
- return out
-
-class SCNet(nn.Module):
- """ SCNet Variants Definations
- Parameters
- ----------
- block : Block
- Class for the residual block.
- layers : list of int
- Numbers of layers in each block.
- classes : int, default 1000
- Number of classification classes.
- dilated : bool, default False
- Applying dilation strategy to pretrained SCNet yielding a stride-8 model.
- deep_stem : bool, default False
- Replace 7x7 conv in input stem with 3 3x3 conv.
- avg_down : bool, default False
- Use AvgPool instead of stride conv when
- downsampling in the bottleneck.
- norm_layer : object
- Normalization layer used (default: :class:`torch.nn.BatchNorm2d`).
- Reference:
- - He, Kaiming, et al. "Deep residual learning for image recognition."
- Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
- - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
- """
- def __init__(self, block, layers, groups=1, bottleneck_width=32,
- num_classes=1000, dilated=False, dilation=1,
- deep_stem=False, stem_width=64, avg_down=False,
- avd=False, norm_layer=nn.BatchNorm2d):
- self.cardinality = groups
- self.bottleneck_width = bottleneck_width
- # ResNet-D params
- self.inplanes = stem_width*2 if deep_stem else 64
- self.avg_down = avg_down
- self.avd = avd
-
- super(SCNet, self).__init__()
- conv_layer = nn.Conv2d
- if deep_stem:
- self.conv1 = nn.Sequential(
- conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False),
- norm_layer(stem_width),
- nn.ReLU(inplace=True),
- conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False),
- norm_layer(stem_width),
- nn.ReLU(inplace=True),
- conv_layer(stem_width, stem_width*2, kernel_size=3, stride=1, padding=1, bias=False),
- )
- else:
- self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3,
- bias=False)
- self.bn1 = norm_layer(self.inplanes)
- self.relu = nn.ReLU(inplace=True)
- self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
- self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, is_first=False)
- self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
- if dilated or dilation == 4:
- self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
- dilation=2, norm_layer=norm_layer)
- self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
- dilation=4, norm_layer=norm_layer)
- elif dilation==2:
- self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
- dilation=1, norm_layer=norm_layer)
- self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
- dilation=2, norm_layer=norm_layer)
- else:
- self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
- norm_layer=norm_layer)
- self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
- norm_layer=norm_layer)
- self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
- self.fc = nn.Linear(512 * block.expansion, num_classes)
-
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
- elif isinstance(m, norm_layer):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
-
- def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None,
- is_first=True):
- downsample = None
- if stride != 1 or self.inplanes != planes * block.expansion:
- down_layers = []
- if self.avg_down:
- if dilation == 1:
- down_layers.append(nn.AvgPool2d(kernel_size=stride, stride=stride,
- ceil_mode=True, count_include_pad=False))
- else:
- down_layers.append(nn.AvgPool2d(kernel_size=1, stride=1,
- ceil_mode=True, count_include_pad=False))
- down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
- kernel_size=1, stride=1, bias=False))
- else:
- down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
- kernel_size=1, stride=stride, bias=False))
- down_layers.append(norm_layer(planes * block.expansion))
- downsample = nn.Sequential(*down_layers)
-
- layers = []
- if dilation == 1 or dilation == 2:
- layers.append(block(self.inplanes, planes, stride, downsample=downsample,
- cardinality=self.cardinality,
- bottleneck_width=self.bottleneck_width,
- avd=self.avd, dilation=1, is_first=is_first,
- norm_layer=norm_layer))
- elif dilation == 4:
- layers.append(block(self.inplanes, planes, stride, downsample=downsample,
- cardinality=self.cardinality,
- bottleneck_width=self.bottleneck_width,
- avd=self.avd, dilation=2, is_first=is_first,
- norm_layer=norm_layer))
- else:
- raise RuntimeError("=> unknown dilation size: {}".format(dilation))
-
- self.inplanes = planes * block.expansion
- for i in range(1, blocks):
- layers.append(block(self.inplanes, planes,
- cardinality=self.cardinality,
- bottleneck_width=self.bottleneck_width,
- avd=self.avd, dilation=dilation,
- norm_layer=norm_layer))
-
- return nn.Sequential(*layers)
-
- def forward(self, x):
- x = self.conv1(x)
- x = self.bn1(x)
- x = self.relu(x)
- x = self.maxpool(x)
-
- x2 = self.layer1(x)
- x3 = self.layer2(x2)
- x4 = self.layer3(x3)
- x5 = self.layer4(x4)
-
- return x2, x3, x4, x5
-
-
-
-def scnet50(pretrained=False, **kwargs):
- """Constructs a SCNet-50 model.
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- model = SCNet(SCBottleneck, [3, 4, 6, 3],
- deep_stem=False, stem_width=32, avg_down=False,
- avd=False, **kwargs)
- if pretrained:
- model.load_state_dict(model_zoo.load_url(model_urls['scnet50']),strict=False)
- return model
-
-def scnet50_v1d(pretrained=False, **kwargs):
- """Constructs a SCNet-50_v1d model described in
- `Bag of Tricks `_.
- `ResNeSt: Split-Attention Networks `_.
- Compared with default SCNet(SCNetv1b), SCNetv1d replaces the 7x7 conv
- in the input stem with three 3x3 convs. And in the downsampling block,
- a 3x3 avg_pool with stride 2 is added before conv, whose stride is
- changed to 1.
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- model = SCNet(SCBottleneck, [3, 4, 6, 3],
- deep_stem=True, stem_width=32, avg_down=True,
- avd=True, **kwargs)
- if pretrained:
- model.load_state_dict(model_zoo.load_url(model_urls['scnet50_v1d']),strict=False)
- return model
diff --git a/ptocr/model/backbone/rec_crnn_backbone.py b/ptocr/model/backbone/rec_crnn_backbone.py
deleted file mode 100644
index c401800..0000000
--- a/ptocr/model/backbone/rec_crnn_backbone.py
+++ /dev/null
@@ -1,79 +0,0 @@
-#-*- coding:utf-8 _*-
-"""
-@author:fxw
-@file: crnn_backbone.py
-@time: 2020/10/12
-"""
-import torch
-import torch.nn as nn
-
-class conv_bn_relu(nn.Module):
- def __init__(self,in_c,out_c,k_s,s,p,with_bn=True):
- super(conv_bn_relu, self).__init__()
- self.conv = nn.Conv2d(in_c,out_c,k_s,s,p)
- self.bn = nn.BatchNorm2d(out_c)
- self.relu = nn.ReLU()
- self.with_bn = with_bn
- def forward(self, x):
- x = self.conv(x)
- if self.with_bn:
- x = self.bn(x)
- x = self.relu(x)
- return x
-
-class crnn_backbone(nn.Module):
- def __init__(self,is_gray):
- super(crnn_backbone, self).__init__()
- if(is_gray):
- nc = 1
- else:
- nc = 3
- base_channel = 64
- self.cnn = nn.Sequential(
- conv_bn_relu(nc,base_channel,3,1,1),
- nn.MaxPool2d(2,2),
- conv_bn_relu(base_channel,base_channel*2,3,1,1),
- nn.MaxPool2d(2, 2),
- conv_bn_relu(base_channel*2,base_channel*4,3,1,1),
- conv_bn_relu(base_channel*4,base_channel*4,3,1,1),
- nn.MaxPool2d((2,1),(2,1)),
- conv_bn_relu(base_channel*4,base_channel*8,3,1,1,with_bn=True),
- conv_bn_relu(base_channel*8,base_channel*8,3,1,1,with_bn=True),
- nn.MaxPool2d((2,1), (2,1)),
- conv_bn_relu(base_channel*8,base_channel*8,(2,1),1,0)
- # conv_bn_relu(base_channel*8,base_channel*8,2,1,0)
- )
-
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight.data)
- elif isinstance(m, nn.BatchNorm2d):
- m.weight.data.fill_(1.)
- m.bias.data.fill_(1e-4)
-
- def forward(self, x):
- x = self.cnn(x)
- return x
-
-def rec_crnn_backbone(pretrained=False, is_gray=False,**kwargs):
- """VGG 19-layer model (configuration 'E') with batch normalization
-
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- if pretrained:
- kwargs['init_weights'] = False
- model = crnn_backbone(is_gray)
- if pretrained:
- pretrained_model = torch.load('./pre_model/crnn_backbone.pth')
- state = model.state_dict()
- for key in state.keys():
- if key in pretrained_model.keys():
- if (key=='features.0.weight' and is_gray):
- state[key] = torch.mean(pretrained_model[key],1).unsqueeze(1)
- else:
- state[key] = pretrained_model[key]
- model.load_state_dict(state)
- return model
-
-
diff --git a/ptocr/model/backbone/rec_mobilev3_bd.py b/ptocr/model/backbone/rec_mobilev3_bd.py
new file mode 100644
index 0000000..417abba
--- /dev/null
+++ b/ptocr/model/backbone/rec_mobilev3_bd.py
@@ -0,0 +1,279 @@
+import torch.nn.functional as F
+import torch.nn as nn
+
+
+class hswish(nn.Module):
+ def forward(self, x):
+ out = x * F.relu6(x + 3, inplace=True) / 6
+ return out
+
+class hsigmoid(nn.Module):
+ def forward(self, x):
+ out = F.relu6(x + 3, inplace=True) / 6
+ return out
+
+def make_divisible(v, divisor=8, min_value=None):
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+class ConvBNLayer(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ groups=1,
+ if_act=True,
+ act=None):
+ super(ConvBNLayer, self).__init__()
+ self.if_act = if_act
+ self.act = act
+ if self.act == "hardswish":
+ self.hardswish = hswish()
+
+ self.conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups)
+
+ self.bn = nn.BatchNorm2d(out_channels)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ if self.if_act:
+ if self.act == "relu":
+ x = F.relu(x)
+ elif self.act == "hardswish":
+ x = self.hardswish(x)
+ else:
+ print("The activation function({}) is selected incorrectly.".
+ format(self.act))
+ exit()
+ return x
+
+class ResidualUnit(nn.Module):
+ def __init__(self,
+ in_channels,
+ mid_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ use_se,
+ act=None):
+ super(ResidualUnit, self).__init__()
+ self.if_shortcut = stride == 1 and in_channels == out_channels
+ self.if_se = use_se
+
+ self.expand_conv = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=mid_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ if_act=True,
+ act=act)
+ self.bottleneck_conv = ConvBNLayer(
+ in_channels=mid_channels,
+ out_channels=mid_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=int((kernel_size - 1) // 2),
+ groups=mid_channels,
+ if_act=True,
+ act=act)
+ if self.if_se:
+ self.mid_se = SEModule(mid_channels)
+ self.linear_conv = ConvBNLayer(
+ in_channels=mid_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ if_act=False,
+ act=None)
+
+ def forward(self, inputs):
+ x = self.expand_conv(inputs)
+ x = self.bottleneck_conv(x)
+ if self.if_se:
+ x = self.mid_se(x)
+ x = self.linear_conv(x)
+ if self.if_shortcut:
+ x = inputs+x
+ return x
+
+
+class SEModule(nn.Module):
+ def __init__(self, in_channels, reduction=4):
+ super(SEModule, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.conv1 = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=in_channels // reduction,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.conv2 = nn.Conv2d(
+ in_channels=in_channels // reduction,
+ out_channels=in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.hardsigmoid = hsigmoid()
+
+ def forward(self, inputs):
+ outputs = self.avg_pool(inputs)
+ outputs = self.conv1(outputs)
+ outputs = F.relu(outputs)
+ outputs = self.conv2(outputs)
+ outputs = self.hardsigmoid(outputs)
+ return inputs * outputs
+
+class MobileNetV3(nn.Module):
+ def __init__(self,
+ in_channels=3,
+ model_name='small',
+ scale=0.5,
+ large_stride=None,
+ small_stride=None,
+ **kwargs):
+ super(MobileNetV3, self).__init__()
+ if small_stride is None:
+ small_stride = [1, 2, 2, 2]
+ if large_stride is None:
+ large_stride = [1, 2, 2, 2]
+
+ assert isinstance(large_stride, list), "large_stride type must " \
+ "be list but got {}".format(type(large_stride))
+ assert isinstance(small_stride, list), "small_stride type must " \
+ "be list but got {}".format(type(small_stride))
+ assert len(large_stride) == 4, "large_stride length must be " \
+ "4 but got {}".format(len(large_stride))
+ assert len(small_stride) == 4, "small_stride length must be " \
+ "4 but got {}".format(len(small_stride))
+
+ if model_name == "large":
+ cfg = [
+ # k, exp, c, se, nl, s,
+ [3, 16, 16, False, 'relu', large_stride[0]],
+ [3, 64, 24, False, 'relu', (large_stride[1], 1)],
+ [3, 72, 24, False, 'relu', 1],
+ [5, 72, 40, True, 'relu', (large_stride[2], 1)],
+ [5, 120, 40, True, 'relu', 1],
+ [5, 120, 40, True, 'relu', 1],
+ [3, 240, 80, False, 'hardswish', 1],
+ [3, 200, 80, False, 'hardswish', 1],
+ [3, 184, 80, False, 'hardswish', 1],
+ [3, 184, 80, False, 'hardswish', 1],
+ [3, 480, 112, True, 'hardswish', 1],
+ [3, 672, 112, True, 'hardswish', 1],
+ [5, 672, 160, True, 'hardswish', (large_stride[3], 1)],
+ [5, 960, 160, True, 'hardswish', 1],
+ [5, 960, 160, True, 'hardswish', 1],
+ ]
+ cls_ch_squeeze = 960
+ elif model_name == "small":
+ cfg = [
+ # k, exp, c, se, nl, s,
+ [3, 16, 16, True, 'relu', (small_stride[0], 1)],
+ [3, 72, 24, False, 'relu', (small_stride[1], 1)],
+ [3, 88, 24, False, 'relu', 1],
+ [5, 96, 40, True, 'hardswish', (small_stride[2], 1)],
+ [5, 240, 40, True, 'hardswish', 1],
+ [5, 240, 40, True, 'hardswish', 1],
+ [5, 120, 48, True, 'hardswish', 1],
+ [5, 144, 48, True, 'hardswish', 1],
+ [5, 288, 96, True, 'hardswish', (small_stride[3], 1)],
+ [5, 576, 96, True, 'hardswish', 1],
+ [5, 576, 96, True, 'hardswish', 1],
+ ]
+ cls_ch_squeeze = 576
+ else:
+ raise NotImplementedError("mode[" + model_name +
+ "_model] is not implemented!")
+
+ supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
+ assert scale in supported_scale, \
+ "supported scales are {} but input scale is {}".format(supported_scale, scale)
+
+ inplanes = 16
+ # conv1
+ self.conv1 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=make_divisible(inplanes * scale),
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ groups=1,
+ if_act=True,
+ act='hardswish')
+ block_list = []
+ inplanes = make_divisible(inplanes * scale)
+ for (k, exp, c, se, nl, s) in cfg:
+ block_list.append(
+ ResidualUnit(
+ in_channels=inplanes,
+ mid_channels=make_divisible(scale * exp),
+ out_channels=make_divisible(scale * c),
+ kernel_size=k,
+ stride=s,
+ use_se=se,
+ act=nl))
+ inplanes = make_divisible(scale * c)
+ self.blocks = nn.Sequential(*block_list)
+
+ self.conv2 = ConvBNLayer(
+ in_channels=inplanes,
+ out_channels=make_divisible(scale * cls_ch_squeeze),
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ groups=1,
+ if_act=True,
+ act='hardswish')
+
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
+ self.out_channels = make_divisible(scale * cls_ch_squeeze)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.blocks(x)
+ x = self.conv2(x)
+ x = self.pool(x)
+ return x
+
+
+def mobilenet_v3_small(pretrained, is_gray=False,**kwargs):
+ if is_gray:
+ in_channels = 1
+ else:
+ in_channels = 3
+ model = MobileNetV3( in_channels=in_channels,
+ model_name='small',
+ scale = 1)
+ if pretrained:
+ pass
+ return model
+
+def mobilenet_v3_large(pretrained,is_gray=False,**kwargs):
+ if is_gray:
+ in_channels = 1
+ else:
+ in_channels = 3
+ model = MobileNetV3( in_channels=in_channels,
+ model_name='large',
+ scale=1)
+ if pretrained:
+ pass
+ return model
\ No newline at end of file
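
The new recognition MobileNetV3 only downsamples height inside its stages (the `(stride, 1)` tuples), so a text-line crop comes out one pixel high with the width preserved for sequence decoding. A minimal shape check, assuming the repo root is on `PYTHONPATH`:

```python
import torch
from ptocr.model.backbone.rec_mobilev3_bd import mobilenet_v3_small

model = mobilenet_v3_small(pretrained=False, is_gray=False)
x = torch.randn(1, 3, 32, 320)   # a 32-pixel-high text-line crop
y = model(x)
# conv1 (stride 2) halves H and W, the three (2, 1) stages halve H only,
# and the trailing 2x2 max-pool halves both:
print(y.shape)                   # torch.Size([1, 576, 1, 80])
```
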
diff --git a/ptocr/model/backbone/rec_vgg.py b/ptocr/model/backbone/rec_vgg.py
deleted file mode 100644
index 40d1c95..0000000
--- a/ptocr/model/backbone/rec_vgg.py
+++ /dev/null
@@ -1,258 +0,0 @@
-#-*- coding:utf-8 _*-
-"""
-@author:fxw
-@file: vgg.py
-@time: 2020/07/24
-"""
-import torch
-import torch.nn as nn
-import torch.utils.model_zoo as model_zoo
-
-
-__all__ = [
- 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
- 'vgg19_bn', 'vgg19',
-]
-
-
-model_urls = {
- 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
- 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
- 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
- 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
- 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
- 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
- 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
- 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
-}
-
-
-class VGG(nn.Module):
-
- def __init__(self, features, num_classes=1000, init_weights=True):
- super(VGG, self).__init__()
- self.features = features
- if init_weights:
- self._initialize_weights()
-
- def forward(self, x):
- x = self.features(x)
- return x
-
- def _initialize_weights(self):
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
- if m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.BatchNorm2d):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.Linear):
- nn.init.normal_(m.weight, 0, 0.01)
- nn.init.constant_(m.bias, 0)
-
-
-def make_layers(cfg, is_gray=False,batch_norm=False):
- layers = []
- in_channels = 3
- if is_gray:
- in_channels = 1
- for v in cfg:
- if v == 'M':
- layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
- elif v=='N':
- layers += [nn.MaxPool2d(kernel_size=2, stride=(2,1))]
- else:
- conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
- if batch_norm:
- layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
- else:
- layers += [conv2d, nn.ReLU(inplace=True)]
- in_channels = v
- return nn.Sequential(*layers)
-
-
-cfg = {
- 'A': [64, 'M', 128, 'M', 256, 256, 'N', 512, 512, 'N', 512, 512, 'N'],
- 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'N', 512, 512, 'N', 512, 512, 'N'],
- 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'N', 512, 512, 512, 'N', 512, 512, 512, 'N'],
- 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'N', 512, 512, 512, 512, 'N'],
-}
-
-
-def vgg11(pretrained=False,is_gray=False, **kwargs):
- """VGG 11-layer model (configuration "A")
-
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- if pretrained:
- kwargs['init_weights'] = False
- model = VGG(make_layers(cfg['A'],is_gray=is_gray), **kwargs)
- if pretrained:
- pretrained_model = model_zoo.load_url(model_urls['vgg11'])
- state = model.state_dict()
- for key in state.keys():
- if key in pretrained_model.keys():
- if (key=='features.0.weight' and is_gray):
- state[key] = torch.mean(pretrained_model[key],1).unsqueeze(1)
- else:
- state[key] = pretrained_model[key]
- model.load_state_dict(state)
- return model
-
-
-def vgg11_bn(pretrained=False,is_gray=False, **kwargs):
- """VGG 11-layer model (configuration "A") with batch normalization
-
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- if pretrained:
- kwargs['init_weights'] = False
- model = VGG(make_layers(cfg['A'], is_gray=is_gray,batch_norm=True), **kwargs)
- if pretrained:
- pretrained_model = model_zoo.load_url(model_urls['vgg11_bn'])
- state = model.state_dict()
- for key in state.keys():
- if key in pretrained_model.keys():
- if (key=='features.0.weight' and is_gray):
- state[key] = torch.mean(pretrained_model[key],1).unsqueeze(1)
- else:
- state[key] = pretrained_model[key]
- model.load_state_dict(state)
- return model
-
-
-def vgg13(pretrained=False,is_gray=False, **kwargs):
- """VGG 13-layer model (configuration "B")
-
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- if pretrained:
- kwargs['init_weights'] = False
- model = VGG(make_layers(cfg['B'],is_gray=is_gray), **kwargs)
- if pretrained:
- pretrained_model = model_zoo.load_url(model_urls['vgg13'])
- state = model.state_dict()
- for key in state.keys():
- if key in pretrained_model.keys():
- if (key=='features.0.weight' and is_gray):
- state[key] = torch.mean(pretrained_model[key],1).unsqueeze(1)
- else:
- state[key] = pretrained_model[key]
- model.load_state_dict(state)
- return model
-
-
-def vgg13_bn(pretrained=False, is_gray=False,**kwargs):
- """VGG 13-layer model (configuration "B") with batch normalization
-
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- if pretrained:
- kwargs['init_weights'] = False
- model = VGG(make_layers(cfg['B'],is_gray=is_gray, batch_norm=True), **kwargs)
- if pretrained:
- pretrained_model = model_zoo.load_url(model_urls['vgg13_bn'])
- state = model.state_dict()
- for key in state.keys():
- if key in pretrained_model.keys():
- if (key=='features.0.weight' and is_gray):
- state[key] = torch.mean(pretrained_model[key],1).unsqueeze(1)
- else:
- state[key] = pretrained_model[key]
- model.load_state_dict(state)
- return model
-
-
-def vgg16(pretrained=False,is_gray=False, **kwargs):
- """VGG 16-layer model (configuration "D")
-
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- if pretrained:
- kwargs['init_weights'] = False
- model = VGG(make_layers(cfg['D'],is_gray=is_gray), **kwargs)
- if pretrained:
- pretrained_model = model_zoo.load_url(model_urls['vgg16'])
- state = model.state_dict()
- for key in state.keys():
- if key in pretrained_model.keys():
- if (key=='features.0.weight' and is_gray):
- state[key] = torch.mean(pretrained_model[key],1).unsqueeze(1)
- else:
- state[key] = pretrained_model[key]
- model.load_state_dict(state)
- return model
-
-
-def vgg16_bn(pretrained=False,is_gray=False, **kwargs):
- """VGG 16-layer model (configuration "D") with batch normalization
-
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- if pretrained:
- kwargs['init_weights'] = False
- model = VGG(make_layers(cfg['D'],is_gray=is_gray, batch_norm=True), **kwargs)
- if pretrained:
- pretrained_model = model_zoo.load_url(model_urls['vgg16_bn'])
- state = model.state_dict()
- for key in state.keys():
- if key in pretrained_model.keys():
- if (key=='features.0.weight' and is_gray):
- state[key] = torch.mean(pretrained_model[key],1).unsqueeze(1)
- else:
- state[key] = pretrained_model[key]
- model.load_state_dict(state)
- return model
-
-
-def vgg19(pretrained=False, is_gray=False,**kwargs):
- """VGG 19-layer model (configuration "E")
-
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- if pretrained:
- kwargs['init_weights'] = False
- model = VGG(make_layers(cfg['E'],is_gray=is_gray), **kwargs)
- if pretrained:
- pretrained_model = model_zoo.load_url(model_urls['vgg19'])
- state = model.state_dict()
- for key in state.keys():
- if key in pretrained_model.keys():
- if (key=='features.0.weight' and is_gray):
- state[key] = torch.mean(pretrained_model[key],1).unsqueeze(1)
- else:
- state[key] = pretrained_model[key]
- model.load_state_dict(state)
- return model
-
-
-def vgg19_bn(pretrained=False, is_gray=False,**kwargs):
- """VGG 19-layer model (configuration 'E') with batch normalization
-
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- if pretrained:
- kwargs['init_weights'] = False
- model = VGG(make_layers(cfg['E'],is_gray=is_gray, batch_norm=True), **kwargs)
- if pretrained:
- pretrained_model = model_zoo.load_url(model_urls['vgg19_bn'])
- state = model.state_dict()
- for key in state.keys():
- if key in pretrained_model.keys():
- if (key=='features.0.weight' and is_gray):
- state[key] = torch.mean(pretrained_model[key],1).unsqueeze(1)
- else:
- state[key] = pretrained_model[key]
- model.load_state_dict(state)
- return model
-
diff --git a/ptocr/model/backbone/reg_resnet.py b/ptocr/model/backbone/reg_resnet.py
deleted file mode 100644
index 832aa58..0000000
--- a/ptocr/model/backbone/reg_resnet.py
+++ /dev/null
@@ -1,368 +0,0 @@
-#-*- coding:utf-8 _*-
-"""
-@author:fxw
-@file: det_resnet.py.py
-@time: 2020/08/07
-"""
-import torch
-import torch.nn as nn
-import math
-import torch.utils.model_zoo as model_zoo
-
-BatchNorm2d = nn.BatchNorm2d
-
-__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
- 'resnet152']
-
-model_urls = {
- 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
- 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
- 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
- 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
- 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
-}
-
-
-def load_pre_model(model,pre_model_path):
- pre_dict = torch.load(pre_model_path)
- model_pre_dict = {}
- for key in model.state_dict().keys():
- if('model.module.backbone.'+key in pre_dict.keys()):
- model_pre_dict[key] = pre_dict['model.module.backbone.'+key]
- else:
- model_pre_dict[key] = model.state_dict()[key]
- model.load_state_dict(model_pre_dict)
- return model
-
-
-def constant_init(module, constant, bias=0):
- nn.init.constant_(module.weight, constant)
- if hasattr(module, 'bias'):
- nn.init.constant_(module.bias, bias)
-
-
-def conv3x3(in_planes, out_planes, stride=1):
- """3x3 convolution with padding"""
- return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
- padding=1, bias=False)
-
-
-class BasicBlock(nn.Module):
- expansion = 1
-
- def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
- super(BasicBlock, self).__init__()
- self.with_dcn = dcn is not None
- self.conv1 = conv3x3(inplanes, planes, stride)
- self.bn1 = BatchNorm2d(planes)
- self.relu = nn.ReLU(inplace=True)
- self.with_modulated_dcn = False
- if self.with_dcn:
- fallback_on_stride = dcn.get('fallback_on_stride', False)
- self.with_modulated_dcn = dcn.get('modulated', False)
- # self.conv2 = conv3x3(planes, planes)
- if not self.with_dcn or fallback_on_stride:
- self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
- padding=1, bias=False)
- else:
- deformable_groups = dcn.get('deformable_groups', 1)
- if not self.with_modulated_dcn:
- from models.dcn import DeformConv
- conv_op = DeformConv
- offset_channels = 18
- else:
- from models.dcn import ModulatedDeformConv
- conv_op = ModulatedDeformConv
- offset_channels = 27
- self.conv2_offset = nn.Conv2d(
- planes,
- deformable_groups * offset_channels,
- kernel_size=3,
- padding=1)
- self.conv2 = conv_op(
- planes,
- planes,
- kernel_size=3,
- padding=1,
- deformable_groups=deformable_groups,
- bias=False)
- self.bn2 = BatchNorm2d(planes)
- self.downsample = downsample
- self.stride = stride
-
- def forward(self, x):
- residual = x
-
- out = self.conv1(x)
- out = self.bn1(out)
- out = self.relu(out)
-
- # out = self.conv2(out)
- if not self.with_dcn:
- out = self.conv2(out)
- elif self.with_modulated_dcn:
- offset_mask = self.conv2_offset(out)
- offset = offset_mask[:, :18, :, :]
- mask = offset_mask[:, -9:, :, :].sigmoid()
- out = self.conv2(out, offset, mask)
- else:
- offset = self.conv2_offset(out)
- out = self.conv2(out, offset)
- out = self.bn2(out)
-
- if self.downsample is not None:
- residual = self.downsample(x)
-
- out += residual
- out = self.relu(out)
-
- return out
-
-
-class Bottleneck(nn.Module):
- expansion = 4
-
- def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
- super(Bottleneck, self).__init__()
- self.with_dcn = dcn is not None
- self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
- self.bn1 = BatchNorm2d(planes)
- fallback_on_stride = False
- self.with_modulated_dcn = False
- if self.with_dcn:
- fallback_on_stride = dcn.get('fallback_on_stride', False)
- self.with_modulated_dcn = dcn.get('modulated', False)
- if not self.with_dcn or fallback_on_stride:
- self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
- stride=stride, padding=1, bias=False)
- else:
- deformable_groups = dcn.get('deformable_groups', 1)
- if not self.with_modulated_dcn:
- from models.dcn import DeformConv
- conv_op = DeformConv
- offset_channels = 18
- else:
- from models.dcn import ModulatedDeformConv
- conv_op = ModulatedDeformConv
- offset_channels = 27
- self.conv2_offset = nn.Conv2d(
- planes, deformable_groups * offset_channels,
- kernel_size=3,
- padding=1)
- self.conv2 = conv_op(
- planes, planes, kernel_size=3, padding=1, stride=stride,
- deformable_groups=deformable_groups, bias=False)
- self.bn2 = BatchNorm2d(planes)
- self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
- self.bn3 = BatchNorm2d(planes * 4)
- self.relu = nn.ReLU(inplace=True)
- self.downsample = downsample
- self.stride = stride
- self.dcn = dcn
- self.with_dcn = dcn is not None
-
- def forward(self, x):
- residual = x
-
- out = self.conv1(x)
- out = self.bn1(out)
- out = self.relu(out)
-
- # out = self.conv2(out)
- if not self.with_dcn:
- out = self.conv2(out)
- elif self.with_modulated_dcn:
- offset_mask = self.conv2_offset(out)
- offset = offset_mask[:, :18, :, :]
- mask = offset_mask[:, -9:, :, :].sigmoid()
- out = self.conv2(out, offset, mask)
- else:
- offset = self.conv2_offset(out)
- out = self.conv2(out, offset)
- out = self.bn2(out)
- out = self.relu(out)
-
- out = self.conv3(out)
- out = self.bn3(out)
-
- if self.downsample is not None:
- residual = self.downsample(x)
-
- out += residual
- out = self.relu(out)
-
- return out
-
-
-class ResNet(nn.Module):
- def __init__(self, block, layers, num_classes=1000,
- dcn=None, stage_with_dcn=(False, False, False, False)):
- self.dcn = dcn
- self.stage_with_dcn = stage_with_dcn
- self.inplanes = 64
- super(ResNet, self).__init__()
- self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
- bias=False)
- self.bn1 = BatchNorm2d(64)
- self.relu = nn.ReLU(inplace=True)
- self.maxpool = nn.MaxPool2d(kernel_size=3, stride=(2,1), padding=1)
- self.layer1 = self._make_layer(block, 64, layers[0])
- self.layer2 = self._make_layer(
- block, 128, layers[1], stride=2, dcn=dcn)
- self.layer3 = self._make_layer(
- block, 256, layers[2], stride=(2,1), dcn=dcn)
- self.layer4 = self._make_layer(
- block, 512, layers[3], stride=(2,1), dcn=dcn)
- # self.avgpool = nn.AvgPool2d(7, stride=1)
- # self.fc = nn.Linear(512 * block.expansion, num_classes)
-
- # self.smooth = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=1)
-
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
- m.weight.data.normal_(0, math.sqrt(2. / n))
- elif isinstance(m, BatchNorm2d):
- m.weight.data.fill_(1)
- m.bias.data.zero_()
- if self.dcn is not None:
- for m in self.modules():
- if isinstance(m, Bottleneck) or isinstance(m, BasicBlock):
- if hasattr(m, 'conv2_offset'):
- constant_init(m.conv2_offset, 0)
-
- def _make_layer(self, block, planes, blocks, stride=1, dcn=None):
- downsample = None
- if stride != 1 or self.inplanes != planes * block.expansion:
- downsample = nn.Sequential(
- nn.Conv2d(self.inplanes, planes * block.expansion,
- kernel_size=1, stride=stride, bias=False),
- BatchNorm2d(planes * block.expansion),
- )
-
- layers = []
- layers.append(block(self.inplanes, planes,
- stride, downsample, dcn=dcn))
- self.inplanes = planes * block.expansion
- for i in range(1, blocks):
- layers.append(block(self.inplanes, planes, dcn=dcn))
-
- return nn.Sequential(*layers)
-
- def forward(self, x):
- x = self.conv1(x)
- x = self.bn1(x)
- x = self.relu(x)
- x = self.maxpool(x)
-
- x = self.layer1(x)
- x = self.layer2(x)
- x = self.layer3(x)
- x = self.layer4(x)
-
- return x
-
-
-def resnet18(pretrained=True, load_url=False,**kwargs):
- """Constructs a ResNet-18 model.
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
- if pretrained:
- if load_url:
- model.load_state_dict(model_zoo.load_url(
- model_urls['resnet50']), strict=False)
- else:
- model = load_pre_model(model,'./pre_model/pre-trained-model-synthtext-resnet18.pth')
- return model
-
-
-def deformable_resnet18(pretrained=True,load_url=False, **kwargs):
- """Constructs a ResNet-18 model.
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- model = ResNet(BasicBlock, [2, 2, 2, 2],
- dcn=dict(modulated=True,
- deformable_groups=1,
- fallback_on_stride=False),
- stage_with_dcn=[False, True, True, True], **kwargs)
- if pretrained:
- if load_url:
- model.load_state_dict(model_zoo.load_url(
- model_urls['resnet50']), strict=False)
- else:
- model = load_pre_model(model,'./pre_model/pre-trained-model-synthtext-resnet18.pth')
- return model
-
-
-def resnet34(pretrained=True, **kwargs):
- """Constructs a ResNet-34 model.
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
- if pretrained:
- model.load_state_dict(model_zoo.load_url(
- model_urls['resnet34']), strict=False)
- return model
-
-
-def resnet50(pretrained=True,load_url=False,**kwargs):
- """Constructs a ResNet-50 model.
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
- if pretrained:
- if load_url:
- model.load_state_dict(model_zoo.load_url(
- model_urls['resnet50']), strict=False)
- else:
- model = load_pre_model(model,'./pre_model/pre-trained-model-synthtext-resnet50.pth')
- return model
-
-
-def deformable_resnet50(pretrained=True,load_url=False, **kwargs):
- """Constructs a ResNet-50 model with deformable conv.
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- model = ResNet(Bottleneck, [3, 4, 6, 3],
- dcn=dict(modulated=True,
- deformable_groups=1,
- fallback_on_stride=False),
- stage_with_dcn=[False, True, True, True],
- **kwargs)
- if pretrained:
- if load_url:
- model.load_state_dict(model_zoo.load_url(
- model_urls['resnet50']), strict=False)
- else:
- model = load_pre_model(model,'./pre_model/pre-trained-model-synthtext-resnet50.pth')
- return model
-
-
-def resnet101(pretrained=True, **kwargs):
- """Constructs a ResNet-101 model.
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
- if pretrained:
- model.load_state_dict(model_zoo.load_url(
- model_urls['resnet101']), strict=False)
- return model
-
-
-def resnet152(pretrained=True, **kwargs):
- """Constructs a ResNet-152 model.
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
- if pretrained:
- model.load_state_dict(model_zoo.load_url(
- model_urls['resnet152']), strict=False)
- return model
diff --git a/ptocr/model/backbone/reg_resnet_bd.py b/ptocr/model/backbone/reg_resnet_bd.py
new file mode 100644
index 0000000..0840aa2
--- /dev/null
+++ b/ptocr/model/backbone/reg_resnet_bd.py
@@ -0,0 +1,267 @@
+import torch.nn.functional as F
+import torch.nn as nn
+
+class ConvBNLayer(nn.Module):
+ def __init__(self,in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ is_relu=False,
+ is_vd_mode=False):
+ super(ConvBNLayer,self).__init__()
+
+ self.is_vd_mode = is_vd_mode
+ self.is_relu = is_relu
+
+ if is_vd_mode:
+ self._pool2d_avg = nn.AvgPool2d(kernel_size=stride, stride=stride, padding=0, ceil_mode=True)
+
+ self.conv = nn.Conv2d(in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=1 if is_vd_mode else stride,
+ padding=(kernel_size - 1) // 2,
+ groups=groups)
+ self.bn = nn.BatchNorm2d(out_channels)
+ self.relu = nn.ReLU()
+
+ def forward(self,x):
+ if self.is_vd_mode:
+ x = self._pool2d_avg(x)
+ x = self.bn(self.conv(x))
+ if self.is_relu:
+ x = self.relu(x)
+ return x
+
+
+class BottleneckBlock(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False):
+ super(BottleneckBlock, self).__init__()
+
+ self.conv0 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ is_relu=True)
+
+ self.conv1 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride,
+ is_relu=True)
+
+ self.conv2 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels * 4,
+ kernel_size=1)
+
+ if not shortcut:
+ self.short = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels * 4,
+ kernel_size=1,
+ stride=stride,
+ is_vd_mode=not if_first and stride[0] != 1)
+
+ self.shortcut = shortcut
+
+ def forward(self,x):
+ y = self.conv0(x)
+ y = self.conv2(self.conv1(y))
+ if self.shortcut:
+ short = x
+ else:
+ short = self.short(x)
+ y = y+short
+ y = F.relu(y)
+ return y
+
+
+class BasicBlock(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False):
+ super(BasicBlock, self).__init__()
+ self.stride = stride
+ self.conv0 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride,
+ is_relu=True)
+ self.conv1 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=3)
+
+ if not shortcut:
+ self.short = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=stride,
+ is_vd_mode=not if_first and stride[0] != 1)
+
+ self.shortcut = shortcut
+
+ def forward(self, x):
+ y = self.conv0(x)
+ y = self.conv1(y)
+
+ if self.shortcut:
+ short = x
+ else:
+ short = self.short(x)
+ y = y+short
+ y = F.relu(y)
+ return y
+
+class ResNet(nn.Module):
+ def __init__(self, in_channels=3, layers=50, **kwargs):
+ super(ResNet, self).__init__()
+
+ self.layers = layers
+ supported_layers = [18, 34, 50, 101, 152, 200]
+ assert layers in supported_layers, \
+ "supported layers are {} but input layer is {}".format(
+ supported_layers, layers)
+
+ if layers == 18:
+ depth = [2, 2, 2, 2]
+ elif layers == 34 or layers == 50:
+ depth = [3, 4, 6, 3]
+ elif layers == 101:
+ depth = [3, 4, 23, 3]
+ elif layers == 152:
+ depth = [3, 8, 36, 3]
+ elif layers == 200:
+ depth = [3, 12, 48, 3]
+ num_channels = [64, 256, 512,
+ 1024] if layers >= 50 else [64, 64, 128, 256]
+ num_filters = [64, 128, 256, 512]
+
+ self.conv1_1 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=32,
+ kernel_size=3,
+ stride=1,
+ is_relu=True)
+ self.conv1_2 = ConvBNLayer(
+ in_channels=32,
+ out_channels=32,
+ kernel_size=3,
+ stride=1,
+ is_relu=True)
+ self.conv1_3 = ConvBNLayer(
+ in_channels=32,
+ out_channels=64,
+ kernel_size=3,
+ stride=1,
+ is_relu=True)
+ self.pool2d_max = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ block_list = []
+ if layers >= 50:
+ for block in range(len(depth)):
+ shortcut = False
+ for i in range(depth[block]):
+ if i == 0 and block != 0:
+ stride = (2, 1)
+ else:
+ stride = (1, 1)
+ bottleneck_block = BottleneckBlock(
+ in_channels=num_channels[block]
+ if i == 0 else num_filters[block] * 4,
+ out_channels=num_filters[block],
+ stride=stride,
+ shortcut=shortcut,
+ if_first=block == i == 0)
+ shortcut = True
+ block_list.append(bottleneck_block)
+ self.out_channels = num_filters[block]
+ else:
+ for block in range(len(depth)):
+ shortcut = False
+ for i in range(depth[block]):
+ if i == 0 and block != 0:
+ stride = (2, 1)
+ else:
+ stride = (1, 1)
+
+ basic_block = BasicBlock(
+ in_channels=num_channels[block]
+ if i == 0 else num_filters[block],
+ out_channels=num_filters[block],
+ stride=stride,
+ shortcut=shortcut,
+ if_first=block == i == 0)
+ shortcut = True
+ block_list.append(basic_block)
+ self.out_channels = num_filters[block]
+ self.block = nn.Sequential(*block_list)
+ self.out_pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
+
+ def forward(self, x):
+ y = self.conv1_1(x)
+ y = self.conv1_2(y)
+ y = self.conv1_3(y)
+ y = self.pool2d_max(y)
+ for block in self.block:
+ y = block(y)
+ y = self.out_pool(y)
+ return y
+
+
+
+
+def resnet18(pretrained=False, is_gray=False,**kwargs):
+ """Constructs a ResNet-18 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ if is_gray:
+ in_channels = 1
+ else:
+ in_channels = 3
+ model = ResNet(in_channels=in_channels, layers=18, **kwargs)
+ if pretrained:
+ pass
+ return model
+
+def resnet34(pretrained=False, is_gray=False, **kwargs):
+ """Constructs a ResNet-34 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ if is_gray:
+ in_channels = 1
+ else:
+ in_channels = 3
+ model = ResNet(in_channels=in_channels, layers=34, **kwargs)
+ if pretrained:
+ pass
+ return model
+
+def resnet50(pretrained=False, is_gray=False, **kwargs):
+ """Constructs a ResNet-50 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ if is_gray:
+ in_channels = 1
+ else:
+ in_channels = 3
+ model = ResNet(in_channels=in_channels, layers=50, **kwargs)
+ if pretrained:
+ pass
+ return model
\ No newline at end of file
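
The ResNet-vd recognition backbone follows the same layout contract: stages 2-4 stride with `(2, 1)` so only height shrinks, and `out_pool` collapses it to a height-1 feature sequence. A sketch under the same `PYTHONPATH` assumption:

```python
import torch
from ptocr.model.backbone.reg_resnet_bd import resnet34

model = resnet34(pretrained=False, is_gray=True)
x = torch.randn(2, 1, 32, 320)
y = model(x)
# stem keeps 32x320, pool2d_max -> 16x160, three (2, 1) stages -> 2x160,
# out_pool -> 1x80 with 512 channels for the 34-layer (BasicBlock) variant:
print(y.shape)                   # torch.Size([2, 512, 1, 80])
```
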
diff --git a/ptocr/model/head/__pycache__/__init__.cpython-36.pyc b/ptocr/model/head/__pycache__/__init__.cpython-36.pyc
index c7e77d1..6776b1f 100644
Binary files a/ptocr/model/head/__pycache__/__init__.cpython-36.pyc and b/ptocr/model/head/__pycache__/__init__.cpython-36.pyc differ
diff --git a/ptocr/model/head/__pycache__/det_DBHead.cpython-36.pyc b/ptocr/model/head/__pycache__/det_DBHead.cpython-36.pyc
index 0820132..868b86e 100644
Binary files a/ptocr/model/head/__pycache__/det_DBHead.cpython-36.pyc and b/ptocr/model/head/__pycache__/det_DBHead.cpython-36.pyc differ
diff --git a/ptocr/model/head/__pycache__/det_FPEM_FFM_Head.cpython-36.pyc b/ptocr/model/head/__pycache__/det_FPEM_FFM_Head.cpython-36.pyc
index 1342a4b..5852ece 100644
Binary files a/ptocr/model/head/__pycache__/det_FPEM_FFM_Head.cpython-36.pyc and b/ptocr/model/head/__pycache__/det_FPEM_FFM_Head.cpython-36.pyc differ
diff --git a/ptocr/model/head/__pycache__/det_SASTHead.cpython-36.pyc b/ptocr/model/head/__pycache__/det_SASTHead.cpython-36.pyc
index 11e7a9e..5954053 100644
Binary files a/ptocr/model/head/__pycache__/det_SASTHead.cpython-36.pyc and b/ptocr/model/head/__pycache__/det_SASTHead.cpython-36.pyc differ
diff --git a/ptocr/model/head/__pycache__/rec_CRNNHead.cpython-36.pyc b/ptocr/model/head/__pycache__/rec_CRNNHead.cpython-36.pyc
index 86a2251..297f9e9 100644
Binary files a/ptocr/model/head/__pycache__/rec_CRNNHead.cpython-36.pyc and b/ptocr/model/head/__pycache__/rec_CRNNHead.cpython-36.pyc differ
diff --git a/ptocr/model/head/__pycache__/rec_FCHead.cpython-36.pyc b/ptocr/model/head/__pycache__/rec_FCHead.cpython-36.pyc
new file mode 100644
index 0000000..b66d2eb
Binary files /dev/null and b/ptocr/model/head/__pycache__/rec_FCHead.cpython-36.pyc differ
diff --git a/ptocr/model/head/det_SASTHead.py b/ptocr/model/head/det_SASTHead.py
index 98af128..dea784b 100644
--- a/ptocr/model/head/det_SASTHead.py
+++ b/ptocr/model/head/det_SASTHead.py
@@ -76,8 +76,8 @@ def __init__(self):
super(FPN_Down_Fusion, self).__init__()
self.fpn_down_conv1 = ConvBnRelu(3, 32, 1, 1, 0, with_relu=False)
-# self.fpn_down_conv2 = ConvBnRelu(128, 64, 1, 1, 0, with_relu=False) # for 3*3
- self.fpn_down_conv2 = ConvBnRelu(64, 64, 1, 1, 0, with_relu=False) # for 7*7
+ self.fpn_down_conv2 = ConvBnRelu(128, 64, 1, 1, 0, with_relu=False) # for 3*3
+# self.fpn_down_conv2 = ConvBnRelu(64, 64, 1, 1, 0, with_relu=False) # for 7*7
self.fpn_down_conv3 = ConvBnRelu(256, 128, 1, 1, 0, with_relu=False)
self.fpn_down_conv4 = ConvBnRelu(32, 64, 3, 2, 1, with_relu=False)
diff --git a/ptocr/model/head/rec_CRNNHead.py b/ptocr/model/head/rec_CRNNHead.py
index 36c4f54..6bc452e 100644
--- a/ptocr/model/head/rec_CRNNHead.py
+++ b/ptocr/model/head/rec_CRNNHead.py
@@ -1,107 +1,121 @@
-#-*- coding:utf-8 _*-
+# -*- coding:utf-8 _*-
"""
@author:fxw
@file: crnn_head.py
@time: 2020/07/24
"""
+import torch
import torch.nn as nn
from torch.nn import init
-class SeModule(nn.Module):
- def __init__(self, in_size, reduction=4):
- super(SeModule, self).__init__()
-
- self.se = nn.Sequential(
- nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False),
- nn.BatchNorm2d(in_size // reduction),
- nn.ReLU(inplace=True),
- nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False),
- nn.BatchNorm2d(in_size),
- nn.Sigmoid())
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- init.kaiming_normal_(m.weight, mode='fan_out')
- if m.bias is not None:
- init.constant_(m.bias, 0)
- elif isinstance(m, nn.BatchNorm2d):
- init.constant_(m.weight, 1)
- init.constant_(m.bias, 0)
+
+class channelattention(nn.Module):
+ def __init__(self, time_step):
+ super(channelattention, self).__init__()
+ self.attention = nn.Sequential(nn.Linear(time_step, 8),
+ nn.Linear(8, 1))
+
+ def forward(self, x):
+        x = x.permute(0, 2, 1)        # (b, T, c) -> (b, c, T)
+        att = self.attention(x)       # Linear over the T axis: one weight per channel
+        att = torch.sigmoid(att)
+        out = att * x                 # (b, c, T)
+        return out.permute(2, 0, 1)   # -> (T, b, c), consumed by timeattention
+
+
+class timeattention(nn.Module):
+ def __init__(self, inchannel):
+ super(timeattention, self).__init__()
+ self.attention = nn.Sequential(nn.Linear(inchannel, inchannel // 4),
+ nn.Linear(inchannel // 4, inchannel // 8),
+ nn.Linear(inchannel // 8, 1))
def forward(self, x):
- return x * self.se(x)
-
+        att = self.attention(x)       # Linear over channels: one weight per time step
+        att = torch.sigmoid(att)
+        out = att * x                 # (T, b, c)
+        return out.permute(1, 0, 2)   # back to batch-first (b, T, c) for the LSTMs
+
+
class BLSTM(nn.Module):
def __init__(self, nIn, nHidden, nOut):
super(BLSTM, self).__init__()
- self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
+ self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True, batch_first=True)
self.embedding = nn.Linear(nHidden * 2, nOut)
def forward(self, input):
+ if not hasattr(self, '_flattened'):
+ self.rnn.flatten_parameters()
+ setattr(self, '_flattened', True)
recurrent, _ = self.rnn(input)
- T, b, h = recurrent.size()
- t_rec = recurrent.view(T * b, h)
+ b, T, h = recurrent.size()
+ t_rec = recurrent.contiguous().view(b * T, h)
- output = self.embedding(t_rec) # [T * b, nOut]
- output = output.view(T, b, -1)
+ output = self.embedding(t_rec) # [b * T, nOut]
+ output = output.contiguous().view(b, T, -1)
return output
+
class CRNN_Head(nn.Module):
- def __init__(self,use_conv=False,
- use_attention=False,
+ def __init__(self, use_attention=False,
use_lstm=True,
- lstm_num=2,
- inchannel=512,
- hiddenchannel=128,
- classes=1000):
- super(CRNN_Head,self).__init__()
+ time_step=25,
+ lstm_num=2,
+ inchannel=512,
+ hiddenchannel=128,
+ classes=1000):
+ super(CRNN_Head, self).__init__()
self.use_lstm = use_lstm
self.lstm_num = lstm_num
- self.use_conv = use_conv
- if use_attention:
- self.attention = SeModule(inchannel)
self.use_attention = use_attention
- if(use_lstm):
- assert lstm_num>0 ,Exception('lstm_num need to more than 0 if use_lstm = True')
+
+ if use_attention:
+ self.channel_attention = channelattention(time_step=time_step)
+ self.time_attention = timeattention(inchannel)
+ if (use_lstm):
+            assert lstm_num > 0, 'lstm_num must be greater than 0 when use_lstm=True'
for i in range(lstm_num):
- if(i==0):
- if(lstm_num==1):
- setattr(self, 'lstm_{}'.format(i + 1), BLSTM(inchannel, hiddenchannel,classes))
+ if (i == 0):
+ if (lstm_num == 1):
+ setattr(self, 'lstm_{}'.format(i + 1), BLSTM(inchannel, hiddenchannel, classes))
else:
- setattr(self, 'lstm_{}'.format(i + 1), BLSTM(inchannel,hiddenchannel,hiddenchannel))
- elif(i==lstm_num-1):
+ setattr(self, 'lstm_{}'.format(i + 1), BLSTM(inchannel, hiddenchannel, hiddenchannel))
+ elif (i == lstm_num - 1):
setattr(self, 'lstm_{}'.format(i + 1), BLSTM(hiddenchannel, hiddenchannel, classes))
else:
setattr(self, 'lstm_{}'.format(i + 1), BLSTM(hiddenchannel, hiddenchannel, hiddenchannel))
- elif(use_conv):
- self.out = nn.Conv2d(inchannel, classes, kernel_size=1, padding=0)
else:
- self.out = nn.Linear(inchannel,classes)
+ self.out = nn.Linear(inchannel, classes)
def forward(self, x):
b, c, h, w = x.size()
assert h == 1, "the height of conv must be 1"
-
- if self.use_attention:
- x = self.attention(x)
-
- if(self.use_conv):
- x = self.out(x)
- x = x.squeeze(2)
- x = x.permute(2, 0, 1)
- return x
-
+
x = x.squeeze(2)
- x = x.permute(2, 0, 1) # [w, b, c]
+ x = x.permute(0, 2, 1) # [b, w, c]
+
+        # optional channel/time attention over the sequence features
+ if self.use_attention:
+ x = self.channel_attention(x)
+ x = self.time_attention(x)
+
+        feats = []
if self.use_lstm:
for i in range(self.lstm_num):
x = getattr(self, 'lstm_{}'.format(i + 1))(x)
+                feats.append(x)
else:
+            feats.append(x)
x = self.out(x)
- return x
-
+
+        return x, feats
+
+
def backward_hook(self, module, grad_input, grad_output):
for g in grad_input:
g[g != g] = 0
\ No newline at end of file
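
The reworked CRNN head is batch-first and now returns the intermediate LSTM features alongside the logits, presumably for the center-loss branch mentioned in the changelog. A smoke test of the new interface:

```python
import torch
from ptocr.model.head.rec_CRNNHead import CRNN_Head

head = CRNN_Head(use_attention=False, use_lstm=True, lstm_num=2,
                 inchannel=512, hiddenchannel=128, classes=100)
feat = torch.randn(4, 512, 1, 25)     # backbone output; height must be 1
logits, feats = head(feat)
print(logits.shape)                   # torch.Size([4, 25, 100]), batch-first (b, T, classes)
print(len(feats))                     # one entry per stacked BLSTM
# torch.nn.CTCLoss expects (T, N, C), so a caller would permute first:
log_probs = logits.permute(1, 0, 2).log_softmax(2)
```
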
diff --git a/ptocr/model/head/rec_FCHead.py b/ptocr/model/head/rec_FCHead.py
new file mode 100644
index 0000000..6f6d95a
--- /dev/null
+++ b/ptocr/model/head/rec_FCHead.py
@@ -0,0 +1,106 @@
+import torch
+import torch.nn as nn
+
+class FCModule(nn.Module):
+ """FCModule
+ Args:
+ """
+ def __init__(self,
+ in_channels,
+ out_channels,
+ bias=True,
+ activation='relu',
+ inplace=True,
+ dropout=None,
+ order=('fc', 'act')):
+ super(FCModule, self).__init__()
+ self.order = order
+ self.activation = activation
+ self.inplace = inplace
+
+        self.with_activation = activation is not None
+ self.with_dropout = dropout is not None
+
+ self.fc = nn.Linear(in_channels, out_channels, bias)
+
+ # build activation layer
+        if self.with_activation:
+ # TODO: introduce `act_cfg` and supports more activation layers
+ if self.activation not in ['relu', 'tanh']:
+ raise ValueError('{} is currently not supported.'.format(
+ self.activation))
+ if self.activation == 'relu':
+ self.activate = nn.ReLU(inplace=inplace)
+ elif self.activation == 'tanh':
+ self.activate = nn.Tanh()
+
+ if self.with_dropout:
+ self.dropout = nn.Dropout(p=dropout)
+
+ def forward(self, x):
+ if self.order == ('fc', 'act'):
+ x = self.fc(x)
+
+            if self.with_activation:
+ x = self.activate(x)
+ elif self.order == ('act', 'fc'):
+            if self.with_activation:
+ x = self.activate(x)
+ x = self.fc(x)
+
+ if self.with_dropout:
+ x = self.dropout(x)
+
+ return x
+
+class FCModules(nn.Module):
+ """FCModules
+ Args:
+ """
+ def __init__(self,
+ in_channels,
+ out_channels,
+ bias=True,
+ activation='relu',
+ inplace=True,
+ dropouts=None,
+ num_fcs=1):
+ super().__init__()
+
+ if dropouts is not None:
+ assert num_fcs == len(dropouts)
+ dropout = dropouts[0]
+ else:
+ dropout = None
+
+ layers = [FCModule(in_channels, out_channels, bias, activation, inplace, dropout)]
+ for ii in range(1, num_fcs):
+ if dropouts is not None:
+ dropout = dropouts[ii]
+ else:
+ dropout = None
+ layers.append(FCModule(out_channels, out_channels, bias, activation, inplace, dropout))
+
+ self.block = nn.Sequential(*layers)
+
+ def forward(self, x):
+ feat = self.block(x)
+ return feat
+
+
+class FC_Head(nn.Module):
+ def __init__(self,in_channels,
+ out_channels,max_length,num_class):
+ super(FC_Head,self).__init__()
+ self.adpooling = nn.AdaptiveAvgPool2d(1)
+ self.fc_end = FCModules(in_channels=in_channels,out_channels=out_channels)
+ self.fc_out = nn.Linear(out_channels,(num_class+1)*(max_length+1))
+ self.num_class = num_class
+ self.max_length = max_length
+ def forward(self,x):
+ x = self.adpooling(x)
+ x = x.view(x.shape[0],-1)
+ x = self.fc_end(x)
+ x1 = self.fc_out(x)
+ x2 = x1.view(x1.shape[0],self.max_length+1,self.num_class+1)
+ return x2,x1
\ No newline at end of file
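
FC_Head is a fixed-length (non-CTC) decoder: it global-pools the backbone features and predicts `max_length + 1` character slots with `num_class + 1` logits each, the extra slot and class presumably covering EOS/padding. A shape sketch with hypothetical sizes:

```python
import torch
from ptocr.model.head.rec_FCHead import FC_Head

head = FC_Head(in_channels=576, out_channels=256, max_length=25, num_class=36)
feat = torch.randn(2, 576, 1, 80)
logits, flat = head(feat)
print(logits.shape)   # torch.Size([2, 26, 37]): one row per character slot
print(flat.shape)     # torch.Size([2, 962]): the raw fc output before reshaping
```
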
diff --git a/ptocr/model/loss/__pycache__/__init__.cpython-36.pyc b/ptocr/model/loss/__pycache__/__init__.cpython-36.pyc
index 5584573..6c83201 100644
Binary files a/ptocr/model/loss/__pycache__/__init__.cpython-36.pyc and b/ptocr/model/loss/__pycache__/__init__.cpython-36.pyc differ
diff --git a/ptocr/model/loss/__pycache__/basical_loss.cpython-36.pyc b/ptocr/model/loss/__pycache__/basical_loss.cpython-36.pyc
index 20f2ca5..ca9da33 100644
Binary files a/ptocr/model/loss/__pycache__/basical_loss.cpython-36.pyc and b/ptocr/model/loss/__pycache__/basical_loss.cpython-36.pyc differ
diff --git a/ptocr/model/loss/__pycache__/centerloss.cpython-36.pyc b/ptocr/model/loss/__pycache__/centerloss.cpython-36.pyc
new file mode 100644
index 0000000..cf248be
Binary files /dev/null and b/ptocr/model/loss/__pycache__/centerloss.cpython-36.pyc differ
diff --git a/ptocr/model/loss/__pycache__/ctc_loss.cpython-36.pyc b/ptocr/model/loss/__pycache__/ctc_loss.cpython-36.pyc
index bbc0036..d9b53cc 100644
Binary files a/ptocr/model/loss/__pycache__/ctc_loss.cpython-36.pyc and b/ptocr/model/loss/__pycache__/ctc_loss.cpython-36.pyc differ
diff --git a/ptocr/model/loss/__pycache__/db_loss.cpython-36.pyc b/ptocr/model/loss/__pycache__/db_loss.cpython-36.pyc
index 648b897..d51b198 100644
Binary files a/ptocr/model/loss/__pycache__/db_loss.cpython-36.pyc and b/ptocr/model/loss/__pycache__/db_loss.cpython-36.pyc differ
diff --git a/ptocr/model/loss/__pycache__/fc_loss.cpython-36.pyc b/ptocr/model/loss/__pycache__/fc_loss.cpython-36.pyc
new file mode 100644
index 0000000..ae1ccdc
Binary files /dev/null and b/ptocr/model/loss/__pycache__/fc_loss.cpython-36.pyc differ
diff --git a/ptocr/model/loss/__pycache__/sast_loss.cpython-36.pyc b/ptocr/model/loss/__pycache__/sast_loss.cpython-36.pyc
index d7dfd62..bdc5329 100644
Binary files a/ptocr/model/loss/__pycache__/sast_loss.cpython-36.pyc and b/ptocr/model/loss/__pycache__/sast_loss.cpython-36.pyc differ
diff --git a/ptocr/model/loss/basical_loss.py b/ptocr/model/loss/basical_loss.py
index 649858a..0c74cd3 100644
--- a/ptocr/model/loss/basical_loss.py
+++ b/ptocr/model/loss/basical_loss.py
@@ -6,8 +6,42 @@
"""
import torch
import torch.nn as nn
+import torch.nn.functional as F
import numpy as np
+class MulClassLoss(nn.Module):
+ def __init__(self, ):
+ super(MulClassLoss, self).__init__()
+
+ def forward(self,pre_score,gt_score,n_class):
+ gt_score = gt_score.reshape(-1)
+ index = gt_score>0
+ if index.sum()>0:
+ pre_score = pre_score.permute(0,2,3,1).reshape(-1,n_class)
+ gt_score = gt_score[index]
+ gt_score = gt_score - 1
+ pre_score = pre_score[index]
+ class_loss = F.cross_entropy(pre_score,gt_score.long(), ignore_index=-1)
+ else:
+ class_loss = torch.tensor(0.0).cuda()
+ return class_loss
+
+
+class CrossEntropyLoss(nn.Module):
+
+ def __init__(self, weight=None, size_average=None, ignore_index=-100,
+ reduce=None, reduction='mean'):
+ super(CrossEntropyLoss, self).__init__()
+        self.criterion = nn.CrossEntropyLoss(weight=weight,
+ size_average=size_average,
+ ignore_index=ignore_index,
+ reduce=reduce,
+ reduction=reduction)
+
+ def forward(self, pred, target, *args):
+        return self.criterion(pred.contiguous().view(-1, pred.shape[-1]), target.to(pred.device).contiguous().view(-1))
+
+
class DiceLoss(nn.Module):
def __init__(self,eps=1e-6):
super(DiceLoss,self).__init__()
@@ -183,10 +217,12 @@ def forward(self,
return balance_loss
-def focal_ctc_loss(ctc_loss,alpha=0.25,gamma=0.5): # 0.99,1
+def focal_ctc_loss(ctc_loss,alpha=0.95,gamma=1): # 0.99,1
prob = torch.exp(-ctc_loss)
focal_loss = alpha*(1-prob).pow(gamma)*ctc_loss
- return focal_loss.mean()
+ return focal_loss.sum()
class focal_bin_cross_entropy(nn.Module):
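
For reference, the retuned focal weighting scales each per-sample CTC loss by `alpha * (1 - p) ** gamma` with `p = exp(-loss)`, so easy samples (low loss, `p` near 1) are damped; the switch from `.mean()` to `.sum()` matches the new `reduction='none'` CTC divided by batch size downstream. A standalone numeric sketch mirroring the function above:

```python
import torch

def focal_ctc_loss(ctc_loss, alpha=0.95, gamma=1):
    prob = torch.exp(-ctc_loss)                # ~1 for easy samples, ~0 for hard ones
    return (alpha * (1 - prob).pow(gamma) * ctc_loss).sum()

per_sample = torch.tensor([0.1, 2.0, 5.0])     # hypothetical per-sample CTC losses
print(focal_ctc_loss(per_sample))              # dominated by the hard samples
```
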
diff --git a/ptocr/model/loss/ctc_loss.py b/ptocr/model/loss/ctc_loss.py
index 3aa8c0f..cde7a8f 100644
--- a/ptocr/model/loss/ctc_loss.py
+++ b/ptocr/model/loss/ctc_loss.py
@@ -1,19 +1,22 @@
-# from torch.nn import CTCLoss as PytorchCTCLoss
-from warpctc_pytorch import CTCLoss as PytorchCTCLoss
import torch.nn as nn
-from .basical_loss import focal_ctc_loss
-
class CTCLoss(nn.Module):
def __init__(self,config):
super(CTCLoss,self).__init__()
-# self.criterion = PytorchCTCLoss(reduction = config['loss']['reduction'])
- self.criterion = PytorchCTCLoss()
self.config = config
+ if config['loss']['ctc_type'] == 'warpctc':
+ from warpctc_pytorch import CTCLoss as PytorchCTCLoss
+ self.criterion = PytorchCTCLoss()
+ else:
+ from torch.nn import CTCLoss as PytorchCTCLoss
+ self.criterion = PytorchCTCLoss(reduction = 'none')
def forward(self,pre_batch,gt_batch):
preds,preds_size = pre_batch['preds'],pre_batch['preds_size']
labels,labels_len = gt_batch['labels'],gt_batch['labels_len']
+ if self.config['loss']['ctc_type'] != 'warpctc':
+ preds = preds.log_softmax(2).requires_grad_() # torch.ctcloss
loss = self.criterion(preds, labels, preds_size, labels_len)
- if self.config['loss']['reduction']=='none':
- loss = focal_ctc_loss(loss)
+ if self.config['loss']['use_ctc_weight']:
+ loss = gt_batch['ctc_loss_weight']*loss.cuda()
+ loss = loss.sum()
return loss/self.config['trainload']['batch_size']
\ No newline at end of file
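
The loss module now selects the CTC backend from a `ctc_type` config key and can re-weight per-sample losses via `use_ctc_weight`. A minimal call on the built-in `torch.nn.CTCLoss` path, using only keys that appear in the diff:

```python
import torch
from ptocr.model.loss.ctc_loss import CTCLoss

config = {'loss': {'ctc_type': 'torch', 'use_ctc_weight': False},
          'trainload': {'batch_size': 2}}
criterion = CTCLoss(config)

T, C = 25, 11                                     # time steps; classes incl. blank=0
pre_batch = {'preds': torch.randn(T, 2, C),       # raw logits, (T, N, C)
             'preds_size': torch.full((2,), T, dtype=torch.long)}
gt_batch = {'labels': torch.randint(1, C, (8,)),  # two length-4 targets, concatenated
            'labels_len': torch.tensor([4, 4])}
print(criterion(pre_batch, gt_batch))             # summed loss / batch_size
```
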
diff --git a/ptocr/model/loss/db_loss.py b/ptocr/model/loss/db_loss.py
index d8ed9e7..38dc872 100644
--- a/ptocr/model/loss/db_loss.py
+++ b/ptocr/model/loss/db_loss.py
@@ -6,7 +6,11 @@
"""
import torch
import torch.nn as nn
-from .basical_loss import MaskL1Loss,BalanceCrossEntropyLoss,DiceLoss,FocalCrossEntropyLoss
+from .basical_loss import MaskL1Loss,BalanceCrossEntropyLoss,DiceLoss,FocalCrossEntropyLoss,MulClassLoss
+
+
class DBLoss(nn.Module):
def __init__(self, l1_scale=10, bce_scale=1,eps=1e-6):
super(DBLoss, self).__init__()
@@ -27,4 +31,31 @@ def forward(self, pred_bach, gt_batch):
metrics.update(**l1_metric)
else:
loss = bce_loss
+ return loss, metrics
+
+class DBLossMul(nn.Module):
+ def __init__(self, n_class,l1_scale=10, bce_scale=1, class_scale=1,eps=1e-6):
+ super(DBLossMul, self).__init__()
+ self.dice_loss = DiceLoss(eps)
+ self.l1_loss = MaskL1Loss()
+ self.bce_loss = BalanceCrossEntropyLoss()
+ self.class_loss = MulClassLoss()
+ self.l1_scale = l1_scale
+ self.bce_scale = bce_scale
+ self.class_scale = class_scale
+ self.n_class = n_class
+
+ def forward(self, pred_bach, gt_batch):
+ bce_loss = self.bce_loss(pred_bach['binary'][:,0], gt_batch['gt'], gt_batch['mask'])
+ class_loss = self.class_loss(pred_bach['binary_class'] ,gt_batch['gt_class'],self.n_class)
+ metrics = dict(loss_bce=bce_loss)
+ if 'thresh' in pred_bach:
+ l1_loss, l1_metric = self.l1_loss(pred_bach['thresh'][:,0], gt_batch['thresh_map'], gt_batch['thresh_mask'])
+ dice_loss = self.dice_loss(pred_bach['thresh_binary'][:,0], gt_batch['gt'], gt_batch['mask'])
+ metrics['loss_thresh'] = dice_loss
+ metrics['loss_class'] = class_loss
+ loss = dice_loss + self.l1_scale * l1_loss + bce_loss * self.bce_scale + class_loss * self.class_scale
+ metrics.update(**l1_metric)
+ else:
+            loss = bce_loss + class_loss * self.class_scale  # keep training the class branch even without the thresh head
return loss, metrics
\ No newline at end of file
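
When the threshold branch is active, DBLossMul combines four terms with the constructor defaults `l1_scale=10`, `bce_scale=1`, `class_scale=1`:

```python
# total = dice + 10 * l1 + 1 * bce + 1 * class, e.g. with hypothetical values:
dice_loss, l1_loss, bce_loss, class_loss = 0.30, 0.02, 0.45, 0.12
total = dice_loss + 10 * l1_loss + 1 * bce_loss + 1 * class_loss
print(total)  # 1.07
```
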
diff --git a/ptocr/model/loss/fc_loss.py b/ptocr/model/loss/fc_loss.py
new file mode 100644
index 0000000..293b7e2
--- /dev/null
+++ b/ptocr/model/loss/fc_loss.py
@@ -0,0 +1,14 @@
+import torch
+import torch.nn as nn
+from .basical_loss import CrossEntropyLoss
+
+class FCLoss(nn.Module):
+ def __init__(self,ignore_index = -1):
+ super(FCLoss, self).__init__()
+ self.cross_entropy_loss = CrossEntropyLoss(ignore_index = ignore_index)
+
+
+ def forward(self, pred_bach, gt_batch):
+ loss = self.cross_entropy_loss(pred_bach['pred'],gt_batch['gt'])
+ metrics = dict(loss_fc=loss)
+ return loss, metrics
\ No newline at end of file
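
FCLoss flattens the `(b, max_length+1, num_class+1)` predictions and the `(b, max_length+1)` targets before cross-entropy, so shorter labels can be padded with the `ignore_index` (-1) past their EOS slot. A sketch with hypothetical character ids:

```python
import torch
from ptocr.model.loss.fc_loss import FCLoss

criterion = FCLoss(ignore_index=-1)
pred = torch.randn(2, 26, 37)                    # (batch, max_length+1, num_class+1)
gt = torch.full((2, 26), -1, dtype=torch.long)   # pad all slots with ignore_index
gt[0, :4] = torch.tensor([10, 3, 7, 36])         # three chars + EOS id, hypothetical
gt[1, :2] = torch.tensor([5, 36])
loss, metrics = criterion({'pred': pred}, {'gt': gt})
print(loss, metrics)
```
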
diff --git a/ptocr/model/segout/__pycache__/__init__.cpython-36.pyc b/ptocr/model/segout/__pycache__/__init__.cpython-36.pyc
index 6eb960b..876605d 100644
Binary files a/ptocr/model/segout/__pycache__/__init__.cpython-36.pyc and b/ptocr/model/segout/__pycache__/__init__.cpython-36.pyc differ
diff --git a/ptocr/model/segout/__pycache__/det_DB_segout.cpython-36.pyc b/ptocr/model/segout/__pycache__/det_DB_segout.cpython-36.pyc
index 3368407..844eabc 100644
Binary files a/ptocr/model/segout/__pycache__/det_DB_segout.cpython-36.pyc and b/ptocr/model/segout/__pycache__/det_DB_segout.cpython-36.pyc differ
diff --git a/ptocr/model/segout/__pycache__/det_SAST_segout.cpython-36.pyc b/ptocr/model/segout/__pycache__/det_SAST_segout.cpython-36.pyc
index 882bfe8..5de7975 100644
Binary files a/ptocr/model/segout/__pycache__/det_SAST_segout.cpython-36.pyc and b/ptocr/model/segout/__pycache__/det_SAST_segout.cpython-36.pyc differ
diff --git a/ptocr/model/segout/det_DB_segout.py b/ptocr/model/segout/det_DB_segout.py
index 895472e..2d9ee69 100644
--- a/ptocr/model/segout/det_DB_segout.py
+++ b/ptocr/model/segout/det_DB_segout.py
@@ -88,3 +88,104 @@ def forward(self, fuse,img):
def step_function(self, x, y):
return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
+
+
+class SegDetectorMul(nn.Module):
+ def __init__(self,n_classes = 1,
+ inner_channels=256, k=10,
+ adaptive=False,
+ serial=False, bias=False,
+ *args, **kwargs):
+        '''
+        n_classes: Number of text classes predicted by the extra class head.
+        bias: Whether conv layers have bias or not.
+        adaptive: Whether to use adaptive threshold training or not.
+        serial: If true, thresh prediction will combine segmentation result as input.
+        '''
+ super(SegDetectorMul, self).__init__()
+ self.k = k
+ self.serial = serial
+
+ self.binarize = nn.Sequential(
+ nn.Conv2d(inner_channels, inner_channels // 4, 3, padding=1, bias=bias),
+ nn.BatchNorm2d(inner_channels // 4),
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(inner_channels // 4, inner_channels // 4, 2, 2),
+ nn.BatchNorm2d(inner_channels // 4),
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(inner_channels // 4, 1, 2, 2),
+ nn.Sigmoid()
+ )
+
+ self.classhead = nn.Sequential(
+ nn.Conv2d(inner_channels, inner_channels // 4, 3, padding=1, bias=bias),
+ nn.BatchNorm2d(inner_channels // 4),
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(inner_channels // 4, inner_channels // 4, 2, 2),
+ nn.BatchNorm2d(inner_channels // 4),
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(inner_channels // 4, n_classes, 2, 2)
+ )
+
+
+ self.binarize.apply(self.weights_init)
+ self.classhead.apply(self.weights_init)
+
+ self.adaptive = adaptive
+ if adaptive:
+ self.thresh = self._init_thresh(
+ inner_channels, serial=serial,bias=bias)
+ self.thresh.apply(self.weights_init)
+
+ def weights_init(self, m):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ nn.init.kaiming_normal_(m.weight.data)
+ elif classname.find('BatchNorm') != -1:
+ m.weight.data.fill_(1.)
+ m.bias.data.fill_(1e-4)
+
+ def _init_thresh(self, inner_channels,
+ serial=False, bias=False):
+ in_channels = inner_channels
+ if serial:
+ in_channels += 1
+ self.thresh = nn.Sequential(
+ nn.Conv2d(in_channels, inner_channels //
+ 4, 3, padding=1, bias=bias),
+ nn.BatchNorm2d(inner_channels // 4),
+ nn.ReLU(inplace=True),
+ self._init_upsample(inner_channels // 4, inner_channels // 4),
+ nn.BatchNorm2d(inner_channels // 4),
+ nn.ReLU(inplace=True),
+ self._init_upsample(inner_channels // 4, 1),
+ nn.Sigmoid()
+ )
+ return self.thresh
+
+ def _init_upsample(self, in_channels, out_channels):
+ return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
+
+ def forward(self, fuse,img):
+
+ binary = self.binarize(fuse)
+ binary_class = self.classhead(fuse)
+
+ if self.training:
+ result = OrderedDict(binary=binary)
+ result.update(binary_class = binary_class)
+ else:
+ return binary,binary_class
+
+ if self.adaptive and self.training:
+ if self.serial:
+ fuse = torch.cat(
+ (fuse, nn.functional.interpolate(
+ binary, fuse.shape[2:])), 1)
+ thresh = self.thresh(fuse)
+ thresh_binary = self.step_function(binary, thresh)
+ result.update(thresh=thresh, thresh_binary=thresh_binary)
+ return result
+
+ def step_function(self, x, y):
+ return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
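
SegDetectorMul keeps DB's differentiable binarization, B = 1 / (1 + exp(-k * (P - T))) with k = 10, and only adds the parallel `classhead` that upsamples to an `n_classes` score map (returned alongside the binary map at inference). The step function in isolation:

```python
import torch

def step_function(P, T, k=10):
    # differentiable binarization: a steep sigmoid centered on the threshold map
    return torch.reciprocal(1 + torch.exp(-k * (P - T)))

P = torch.tensor([0.2, 0.5, 0.8])   # probability-map samples
T = torch.tensor([0.5, 0.5, 0.5])   # threshold-map samples
print(step_function(P, T))          # ~[0.047, 0.500, 0.953]
```
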
diff --git a/ptocr/optimizer.py b/ptocr/optimizer.py
index 2df2fcd..9c385f1 100644
--- a/ptocr/optimizer.py
+++ b/ptocr/optimizer.py
@@ -6,21 +6,21 @@
"""
import torch
+import math
-def AdamDecay(config,model):
- optimizer = torch.optim.Adam(model.parameters(), lr=config['optimizer']['base_lr'],
+def AdamDecay(config,parameters):
+ optimizer = torch.optim.Adam(parameters, lr=config['optimizer']['base_lr'],
betas=(config['optimizer']['beta1'], config['optimizer']['beta2']),
weight_decay=config['optimizer']['weight_decay'])
return optimizer
-def SGDDecay(config,model):
- optimizer = torch.optim.SGD(model.parameters(), lr=config['optimizer']['base_lr'],
+def SGDDecay(config,parameters):
+ optimizer = torch.optim.SGD(parameters, lr=config['optimizer']['base_lr'],
momentum=config['optimizer']['momentum'],
weight_decay=config['optimizer']['weight_decay'])
return optimizer
-def RMSPropDecay(config,model):
- optimizer = torch.optim.RMSprop(model.parameters(), lr=config['optimizer']['base_lr'],
+def RMSPropDecay(config,parameters):
+ optimizer = torch.optim.RMSprop(parameters, lr=config['optimizer']['base_lr'],
alpha=config['optimizer']['alpha'],
weight_decay=config['optimizer']['weight_decay'],
momentum=config['optimizer']['momentum'])
@@ -30,11 +30,32 @@ def RMSPropDecay(config,model):
def lr_poly(base_lr, epoch, max_epoch=1200, factor=0.9):
return base_lr*((1-float(epoch)/max_epoch)**(factor))
+
+def SGDR(lr_max,lr_min,T_cur,T_m,ratio=0.3):
+ """
+    :param lr_max: maximum learning rate
+    :param lr_min: minimum learning rate
+    :param T_cur: current epoch (or iteration)
+    :param T_m: restart period, i.e. adjust once every T_m steps
+    :param ratio: decay ratio applied to lr_max at each restart
+    :return: (lr, updated lr_max)
+ """
+ if T_cur % T_m == 0 and T_cur != 0:
+ lr_max = lr_max - lr_max * ratio
+ lr = lr_min+1/2*(lr_max-lr_min)*(1+math.cos((T_cur%T_m/T_m)*math.pi))
+ return lr,lr_max
+
+
def adjust_learning_rate_poly(config, optimizer, epoch):
lr = lr_poly(config['optimizer']['base_lr'], epoch,
config['base']['n_epoch'], config['optimizer_decay']['factor'])
optimizer.param_groups[0]['lr'] = lr
-
+
+def adjust_learning_rate_sgdr(config, optimizer, epoch):
+ lr,lr_max = SGDR(config['optimizer']['lr_max'],config['optimizer']['lr_min'],epoch,config['optimizer']['T_m'],config['optimizer']['ratio'])
+ optimizer.param_groups[0]['lr'] = lr
+ config['optimizer']['lr_max'] = lr_max
+
def adjust_learning_rate(config, optimizer, epoch):
if epoch in config['optimizer_decay']['schedule']:
adjust_lr = optimizer.param_groups[0]['lr'] * config['optimizer_decay']['gama']
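
The new `SGDR`/`adjust_learning_rate_sgdr` pair implements cosine annealing with warm restarts: within each period of `T_m` steps the rate follows a cosine from `lr_max` down to `lr_min`, and at every restart `lr_max` itself is decayed by `ratio`. A standalone sketch of the schedule (the values are hypothetical; the real ones come from the `optimizer` section of the yaml config):

```python
# Mirrors the SGDR helper added above, isolated for inspection.
import math

def sgdr(lr_max, lr_min, t_cur, t_m, ratio=0.3):
    if t_cur % t_m == 0 and t_cur != 0:
        lr_max = lr_max - lr_max * ratio  # decay the peak at each restart
    lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos((t_cur % t_m / t_m) * math.pi))
    return lr, lr_max

lr_max, lr_min = 1e-3, 1e-5
for epoch in range(9):
    lr, lr_max = sgdr(lr_max, lr_min, epoch, t_m=3)
    print(epoch, round(lr, 6))  # cosine decay, restarting at a lower peak every 3 epochs
```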
diff --git a/ptocr/postprocess/DBpostprocess.py b/ptocr/postprocess/DBpostprocess.py
index e483e89..959c270 100644
--- a/ptocr/postprocess/DBpostprocess.py
+++ b/ptocr/postprocess/DBpostprocess.py
@@ -216,4 +216,215 @@ def __call__(self, pred, ratio_list):
boxes_batch.append(boxes)
score_batch.append(score)
- return boxes_batch, score_batch
\ No newline at end of file
+ return boxes_batch, score_batch
+
+class DBPostProcessMul(object):
+ """
+ The post process for Differentiable Binarization (DB).
+ """
+
+ def __init__(self, config):
+ self.thresh = config['postprocess']['thresh']
+ self.box_thresh = config['postprocess']['box_thresh']
+ self.max_candidates = config['postprocess']['max_candidates']
+ self.is_poly = config['postprocess']['is_poly']
+ self.unclip_ratio = config['postprocess']['unclip_ratio']
+ self.min_size = config['postprocess']['min_size']
+
+ def polygons_from_bitmap(self, pred,classes, _bitmap, dest_width, dest_height):
+ '''
+ _bitmap: single map with shape (1, H, W),
+ whose values are binarized as {0, 1}
+ '''
+
+        bitmap = _bitmap
+        height, width = bitmap.shape
+        boxes = []
+        scores = []
+        class_indexes = []
+
+ contours, _ = cv2.findContours(
+ (bitmap*255).astype(np.uint8),
+ cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
+
+ for contour in contours[:self.max_candidates]:
+ epsilon = 0.001 * cv2.arcLength(contour, True)
+ approx = cv2.approxPolyDP(contour, epsilon, True)
+ points = approx.reshape((-1, 2))
+ if points.shape[0] < 4:
+ continue
+ # _, sside = self.get_mini_boxes(contour)
+ # if sside < self.min_size:
+ # continue
+            score, type_class = self.box_score_fast(pred, classes, points.reshape(-1, 2))
+            if self.box_thresh > score:
+                continue
+
+ if points.shape[0] > 2:
+ box = self.unclip(points, self.unclip_ratio)
+ if len(box) > 1:
+ continue
+ else:
+ continue
+            box = box.reshape(-1, 2)
+ _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
+ if sside < self.min_size + 2:
+ continue
+
+ if not isinstance(dest_width, int):
+ dest_width = dest_width.item()
+ dest_height = dest_height.item()
+
+ box[:, 0] = np.clip(
+ np.round(box[:, 0] / width * dest_width), 0, dest_width)
+ box[:, 1] = np.clip(
+ np.round(box[:, 1] / height * dest_height), 0, dest_height)
+            boxes.append(box.tolist())
+            scores.append(score)
+            class_indexes.append(type_class)
+        return boxes, scores, class_indexes
+
+ def boxes_from_bitmap(self, pred,classes, _bitmap, dest_width, dest_height):
+ '''
+ _bitmap: single map with shape (1, H, W),
+ whose values are binarized as {0, 1}
+ '''
+
+ bitmap = _bitmap
+ height, width = bitmap.shape
+
+ outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
+ if len(outs) == 3:
+ img, contours, _ = outs[0], outs[1], outs[2]
+ elif len(outs) == 2:
+ contours, _ = outs[0], outs[1]
+
+ num_contours = min(len(contours), self.max_candidates)
+ boxes = np.zeros((num_contours, 4, 2), dtype=np.int16)
+ scores = np.zeros((num_contours, ), dtype=np.float32)
+ type_classes = np.zeros((num_contours, ), dtype=np.float32)
+
+ for index in range(num_contours):
+ contour = contours[index]
+ points, sside = self.get_mini_boxes(contour)
+ if sside < self.min_size:
+ continue
+ points = np.array(points)
+ score,type_class = self.box_score_fast(pred,classes, points.reshape(-1, 2))
+ if self.box_thresh > score:
+ continue
+ box = self.unclip(points,self.unclip_ratio).reshape(-1, 1, 2)
+ box, sside = self.get_mini_boxes(box)
+ if sside < self.min_size + 2:
+ continue
+ box = np.array(box)
+ if not isinstance(dest_width, int):
+ dest_width = dest_width.item()
+ dest_height = dest_height.item()
+
+ box[:, 0] = np.clip(
+ np.round(box[:, 0] / width * dest_width), 0, dest_width)
+ box[:, 1] = np.clip(
+ np.round(box[:, 1] / height * dest_height), 0, dest_height)
+ boxes[index, :, :] = box.astype(np.int16)
+ scores[index] = score
+ type_classes[index] = type_class
+ return boxes, scores,type_classes
+
+ def unclip(self, box, unclip_ratio=2):
+ poly = Polygon(box)
+ distance = poly.area * unclip_ratio / poly.length
+ offset = pyclipper.PyclipperOffset()
+ offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+ expanded = np.array(offset.Execute(distance))
+ return expanded
+
+ def get_mini_boxes(self, contour):
+ bounding_box = cv2.minAreaRect(contour)
+ points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+ index_1, index_2, index_3, index_4 = 0, 1, 2, 3
+ if points[1][1] > points[0][1]:
+ index_1 = 0
+ index_4 = 1
+ else:
+ index_1 = 1
+ index_4 = 0
+ if points[3][1] > points[2][1]:
+ index_2 = 2
+ index_3 = 3
+ else:
+ index_2 = 3
+ index_3 = 2
+
+ box = [
+ points[index_1], points[index_2], points[index_3], points[index_4]
+ ]
+ return box, min(bounding_box[1])
+
+ def box_score_fast(self, bitmap,classes,_box):
+ h, w = bitmap.shape[:2]
+ box = _box.copy()
+        xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1)
+        xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1)
+        ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1)
+        ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1)
+
+ mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
+ box[:, 0] = box[:, 0] - xmin
+ box[:, 1] = box[:, 1] - ymin
+ cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
+ classes = classes[ymin:ymax + 1, xmin:xmax + 1]
+ return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0],np.argmax(np.bincount(classes.reshape(-1).astype(np.int32)))
+
+ def __call__(self, pred,pred_class, ratio_list):
+ pred = pred[:, 0, :, :]
+ segmentation = pred > self.thresh
+ classes = pred_class[:, 0, :, :]
+
+ boxes_batch = []
+ score_batch = []
+ type_classes_batch = []
+ for batch_index in range(pred.shape[0]):
+ height, width = pred.shape[-2:]
+ if(self.is_poly):
+                tmp_boxes, tmp_scores, type_classes = self.polygons_from_bitmap(
+                    pred[batch_index], classes[batch_index], segmentation[batch_index], width, height)
+
+                boxes = []
+                score = []
+                _classes = []
+                for k in range(len(tmp_boxes)):
+                    if tmp_scores[k] > self.box_thresh:
+                        boxes.append(tmp_boxes[k])
+                        score.append(tmp_scores[k])
+                        _classes.append(type_classes[k])
+ if len(boxes) > 0:
+ ratio_w, ratio_h = ratio_list[batch_index]
+ for i in range(len(boxes)):
+ boxes[i] = np.array(boxes[i])
+ boxes[i][:, 0] = boxes[i][:, 0] * ratio_w
+ boxes[i][:, 1] = boxes[i][:, 1] * ratio_h
+
+                boxes_batch.append(boxes)
+                score_batch.append(score)
+                type_classes_batch.append(_classes)
+ else:
+ tmp_boxes, tmp_scores,type_classes = self.boxes_from_bitmap(
+ pred[batch_index], classes[batch_index],segmentation[batch_index], width, height)
+
+ boxes = []
+ score = []
+ _classes = []
+ for k in range(len(tmp_boxes)):
+ if tmp_scores[k] > self.box_thresh:
+ boxes.append(tmp_boxes[k])
+ score.append(tmp_scores[k])
+ _classes.append(type_classes[k])
+ if len(boxes) > 0:
+ boxes = np.array(boxes)
+
+ ratio_w, ratio_h = ratio_list[batch_index]
+ boxes[:, :, 0] = boxes[:, :, 0] * ratio_w
+ boxes[:, :, 1] = boxes[:, :, 1] * ratio_h
+ type_classes_batch.append(_classes)
+ boxes_batch.append(boxes)
+ score_batch.append(score)
+ return boxes_batch,score_batch,type_classes_batch
\ No newline at end of file
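
`DBPostProcessMul` extends the DB post-process with a per-box class id: `box_score_fast` now also returns the majority class inside each box, and `__call__` returns `(boxes, scores, classes)` per image. A hedged usage sketch with dummy maps (assumes the `shapely`/`pyclipper` dependencies already used by `DBPostProcess`; the config keys follow this repo's yaml layout):

```python
import numpy as np
from ptocr.postprocess.DBpostprocess import DBPostProcessMul

config = {'postprocess': {'thresh': 0.3, 'box_thresh': 0.5, 'max_candidates': 1000,
                          'is_poly': False, 'unclip_ratio': 1.5, 'min_size': 3}}
post = DBPostProcessMul(config)

pred = np.zeros((1, 1, 64, 64), dtype=np.float32)
pred[0, 0, 20:40, 10:50] = 0.9                  # one fake text region
pred_class = np.zeros((1, 1, 64, 64), dtype=np.float32)
pred_class[0, 0, 20:40, 10:50] = 2              # its language/class id

boxes, scores, classes = post(pred, pred_class, ratio_list=[(1.0, 1.0)])
print(boxes[0].shape, scores[0], classes[0])    # 4-point boxes, score, class per box
```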
diff --git a/ptocr/postprocess/__pycache__/DBpostprocess.cpython-36.pyc b/ptocr/postprocess/__pycache__/DBpostprocess.cpython-36.pyc
new file mode 100644
index 0000000..e6bd8b5
Binary files /dev/null and b/ptocr/postprocess/__pycache__/DBpostprocess.cpython-36.pyc differ
diff --git a/ptocr/postprocess/__pycache__/SASTpostprocess.cpython-36.pyc b/ptocr/postprocess/__pycache__/SASTpostprocess.cpython-36.pyc
new file mode 100644
index 0000000..9343f87
Binary files /dev/null and b/ptocr/postprocess/__pycache__/SASTpostprocess.cpython-36.pyc differ
diff --git a/ptocr/postprocess/__pycache__/__init__.cpython-36.pyc b/ptocr/postprocess/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..7936da5
Binary files /dev/null and b/ptocr/postprocess/__pycache__/__init__.cpython-36.pyc differ
diff --git a/ptocr/postprocess/__pycache__/locality_aware_nms.cpython-36.pyc b/ptocr/postprocess/__pycache__/locality_aware_nms.cpython-36.pyc
new file mode 100644
index 0000000..2aa05e9
Binary files /dev/null and b/ptocr/postprocess/__pycache__/locality_aware_nms.cpython-36.pyc differ
diff --git a/ptocr/postprocess/dbprocess/__pycache__/__init__.cpython-36.pyc b/ptocr/postprocess/dbprocess/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..a00defc
Binary files /dev/null and b/ptocr/postprocess/dbprocess/__pycache__/__init__.cpython-36.pyc differ
diff --git a/ptocr/postprocess/lanms/__pycache__/__init__.cpython-36.pyc b/ptocr/postprocess/lanms/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..19a2476
Binary files /dev/null and b/ptocr/postprocess/lanms/__pycache__/__init__.cpython-36.pyc differ
diff --git a/ptocr/utils/__pycache__/__init__.cpython-36.pyc b/ptocr/utils/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..1528833
Binary files /dev/null and b/ptocr/utils/__pycache__/__init__.cpython-36.pyc differ
diff --git a/ptocr/utils/__pycache__/cal_iou_acc.cpython-36.pyc b/ptocr/utils/__pycache__/cal_iou_acc.cpython-36.pyc
new file mode 100644
index 0000000..43a58c4
Binary files /dev/null and b/ptocr/utils/__pycache__/cal_iou_acc.cpython-36.pyc differ
diff --git a/ptocr/utils/__pycache__/gen_teacher_model.cpython-36.pyc b/ptocr/utils/__pycache__/gen_teacher_model.cpython-36.pyc
new file mode 100644
index 0000000..c08392a
Binary files /dev/null and b/ptocr/utils/__pycache__/gen_teacher_model.cpython-36.pyc differ
diff --git a/ptocr/utils/__pycache__/logger.cpython-36.pyc b/ptocr/utils/__pycache__/logger.cpython-36.pyc
new file mode 100644
index 0000000..d1f7b05
Binary files /dev/null and b/ptocr/utils/__pycache__/logger.cpython-36.pyc differ
diff --git a/ptocr/utils/__pycache__/metrics.cpython-36.pyc b/ptocr/utils/__pycache__/metrics.cpython-36.pyc
new file mode 100644
index 0000000..69e442b
Binary files /dev/null and b/ptocr/utils/__pycache__/metrics.cpython-36.pyc differ
diff --git a/ptocr/utils/__pycache__/prune_script.cpython-36.pyc b/ptocr/utils/__pycache__/prune_script.cpython-36.pyc
new file mode 100644
index 0000000..8f331e1
Binary files /dev/null and b/ptocr/utils/__pycache__/prune_script.cpython-36.pyc differ
diff --git a/ptocr/utils/__pycache__/transform_label.cpython-36.pyc b/ptocr/utils/__pycache__/transform_label.cpython-36.pyc
new file mode 100644
index 0000000..b70e4b6
Binary files /dev/null and b/ptocr/utils/__pycache__/transform_label.cpython-36.pyc differ
diff --git a/ptocr/utils/__pycache__/util_function.cpython-36.pyc b/ptocr/utils/__pycache__/util_function.cpython-36.pyc
new file mode 100644
index 0000000..3e86539
Binary files /dev/null and b/ptocr/utils/__pycache__/util_function.cpython-36.pyc differ
diff --git a/ptocr/utils/gen_teacher_model.py b/ptocr/utils/gen_teacher_model.py
index c3f9e74..5a42767 100644
--- a/ptocr/utils/gen_teacher_model.py
+++ b/ptocr/utils/gen_teacher_model.py
@@ -30,7 +30,6 @@ def forward(self,pre_score,gt_score,train_mask):
dice_loss = torch.mean(d)
return 1 - dice_loss
-
def GetTeacherModel(args):
config = yaml.load(open(args.t_config, 'r', encoding='utf-8'), Loader=yaml.FullLoader)
model = create_module(config['architectures']['model_function'])(config)
@@ -41,18 +40,18 @@ def GetTeacherModel(args):
class DistilLoss(nn.Module):
def __init__(self):
+
super(DistilLoss, self).__init__()
self.mse = nn.MSELoss()
self.diceloss = DiceLoss()
self.ignore = ['thresh']
+
def forward(self, s_map, t_map):
loss = 0
-# import pdb
-# pdb.set_trace()
for key in s_map.keys():
if(key in self.ignore):
continue
- loss+=self.diceloss(s_map[key],t_map[key],torch.ones(t_map[key].shape).cuda())
+ loss += self.diceloss(s_map[key],t_map[key],torch.ones(t_map[key].shape).cuda())
return loss
diff --git a/ptocr/utils/logger.py b/ptocr/utils/logger.py
index 4ca3794..16bc333 100644
--- a/ptocr/utils/logger.py
+++ b/ptocr/utils/logger.py
@@ -3,6 +3,34 @@
# (C) Wei YANG 2017
from __future__ import absolute_import
+import logging
+
+class TrainLog(object):
+ def __init__(self,LOG_FILE):
+        file_handler = logging.FileHandler(LOG_FILE)  # write to file
+        console_handler = logging.StreamHandler()  # write to console
+        file_handler.setLevel('INFO')  # INFO and above go to the file
+        console_handler.setLevel('INFO')  # INFO and above go to the console
+
+ fmt = '%(asctime)s - %(funcName)s - %(lineno)s - %(levelname)s - %(message)s'
+ formatter = logging.Formatter(fmt)
+        file_handler.setFormatter(formatter)  # set the output format for both handlers
+ console_handler.setFormatter(formatter)
+
+ logger = logging.getLogger('TrainLog')
+        logger.setLevel('INFO')  # the logger level gates what reaches the handlers
+
+ logger.addHandler(console_handler)
+ logger.addHandler(file_handler)
+ self.logger = logger
+
+ def error(self,char):
+ self.logger.error(char)
+ def debug(self,char):
+ self.logger.debug(char)
+ def info(self,char):
+ self.logger.info(char)
+
class Logger(object):
def __init__(self, fpath, title=None, resume=False):
self.file = None
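
`TrainLog` is a thin wrapper around `logging` that mirrors every message to a file and to the console. A minimal usage sketch (the log path is hypothetical):

```python
from ptocr.utils.logger import TrainLog

log = TrainLog('./train_log.txt')
log.info('epoch 0 | loss_ctc: 3.2175 | lr: 0.00100000')
log.error('NaN loss encountered, skipping batch')  # mirrored to console and file
```

Note that each construction attaches fresh handlers to the same named logger, so build it once per run; the training script in this patch also removes any stale log file before constructing it.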
diff --git a/ptocr/utils/transform_label.py b/ptocr/utils/transform_label.py
index a0d322e..27f4e11 100644
--- a/ptocr/utils/transform_label.py
+++ b/ptocr/utils/transform_label.py
@@ -11,6 +11,7 @@
# import chardet
import numpy as np
import sys
+import abc
def get_keys(key_path):
with open(key_path,'r',encoding='utf-8') as fid:
@@ -18,7 +19,7 @@ def get_keys(key_path):
lines = lines.strip('\n')
return lines
-
class strLabelConverter(object):
"""Convert between str and label.
@@ -125,4 +126,73 @@ def val(self):
res = 0
if self.n_count != 0:
res = self.sum / float(self.n_count)
- return res
\ No newline at end of file
+ return res
+
+
+
+class BaseConverter(object):
+
+ def __init__(self, character):
+ self.character = list(character)
+ self.dict = {}
+ for i, char in enumerate(self.character):
+ self.dict[char] = i
+
+ @abc.abstractmethod
+ def train_encode(self, *args, **kwargs):
+ '''encode text in train phase'''
+
+ @abc.abstractmethod
+ def test_encode(self, *args, **kwargs):
+ '''encode text in test phase'''
+
+ @abc.abstractmethod
+ def decode(self, *args, **kwargs):
+ '''decode label to text in train and test phase'''
+
+
+
+class FCConverter(BaseConverter):
+
+ def __init__(self, config):
+ batch_max_length = config['base']['max_length']
+ character = get_keys(config['trainload']['key_file'])
+ self.character = character
+ list_token = ['[s]']
+ ignore_token = ['[ignore]']
+ list_character = list(character)
+ self.batch_max_length = batch_max_length + 1
+ super(FCConverter, self).__init__(character=list_token + list_character + ignore_token)
+ self.ignore_index = self.dict[ignore_token[0]]
+
+ def encode(self, text):
+ length = [len(s) + 1 for s in text] # +1 for [s] at end of sentence.
+ batch_text = torch.LongTensor(len(text), self.batch_max_length).fill_(self.ignore_index)
+        for i, t in enumerate(text):
+            chars = list(t)
+            chars.append('[s]')
+            ids = [self.dict[char] for char in chars]
+            if self.batch_max_length >= len(ids):
+                batch_text[i][:len(ids)] = torch.LongTensor(ids)
+            else:
+                batch_text[i][:self.batch_max_length] = torch.LongTensor(ids)[:self.batch_max_length]
+ batch_text_input = batch_text
+ batch_text_target = batch_text
+
+ return batch_text_input, torch.IntTensor(length), batch_text_target
+
+ def train_encode(self, text):
+ return self.encode(text)
+
+ def test_encode(self, text):
+ return self.encode(text)
+
+ def decode(self, text_index):
+ texts = []
+ batch_size = text_index.shape[0]
+ for index in range(batch_size):
+ text = ''.join([self.character[i] for i in text_index[index, :]])
+ text = text[:text.find('[s]')]
+ texts.append(text)
+
+ return texts
\ No newline at end of file
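
`FCConverter` pads every label to `max_length + 1` with the `[ignore]` token and terminates it with `[s]`; `decode` cuts each string at the first `[s]`. A hedged round-trip sketch (`keys.txt` is a hypothetical one-line character file written on the spot):

```python
from ptocr.utils.transform_label import FCConverter

# hypothetical character file: one line holding the alphabet
with open('keys.txt', 'w', encoding='utf-8') as fid:
    fid.write('abcdefghijklmnopqrstuvwxyz0123456789')

config = {'base': {'max_length': 25}, 'trainload': {'key_file': 'keys.txt'}}
conv = FCConverter(config)

inputs, lengths, targets = conv.train_encode(['hello', 'world42'])
print(inputs.shape)          # torch.Size([2, 26]) -- max_length + 1
print(conv.decode(targets))  # ['hello', 'world42'], decoding stops at [s]
```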
diff --git a/ptocr/utils/util_function.py b/ptocr/utils/util_function.py
index c2b3dba..5c8a97b 100644
--- a/ptocr/utils/util_function.py
+++ b/ptocr/utils/util_function.py
@@ -10,7 +10,20 @@
import cv2
import torch
import numpy as np
-
+from PIL import Image
+
+def PILImageToCV(img,is_gray=False):
+ img = np.asarray(img)
+ if not is_gray:
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+ return img
+
+def CVImageToPIL(img,is_gray=False):
+ if not is_gray:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = Image.fromarray(img)
+ return img
+
def create_module(module_str):
tmpss = module_str.split(",")
assert len(tmpss) == 2, "Error formate\
@@ -20,6 +33,22 @@ def create_module(module_str):
function = getattr(somemodule, function_name)
return function
+
+def restore_training(model_file,model,optimizer):
+ checkpoint = torch.load(model_file)
+ start_epoch = checkpoint['epoch']
+ try:
+ model.load_state_dict(checkpoint['state_dict'])
+ except:
+ state = model.state_dict()
+ for key in state.keys():
+ state[key] = checkpoint['state_dict'][key[7:]]
+ model.load_state_dict(state)
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ best_acc = checkpoint['best_acc']
+ return model,optimizer,start_epoch,best_acc
+
+
def resize_image_batch(img,algorithm,side_len=1536,add_padding=True):
if(algorithm=='SAST'):
@@ -71,7 +100,21 @@ def resize_image(img,algorithm,side_len=736,stride = 128):
resized_img = cv2.resize(img, (new_width, new_height))
return resized_img
+def resize_image_crnn(img,max_width=100,side_len=32,stride =4):
+ height, width, _ = img.shape
+ new_height = side_len
+
+ new_width = int(math.ceil(new_height / height * width / stride) * stride)
+    if new_width>max_width:
+        resized_img = cv2.resize(img, (max_width, new_height))
+        pad_right = 0  # already at max width (the old code padded a negative amount here)
+    else:
+        resized_img = cv2.resize(img, (new_width, new_height))
+        pad_right = max_width - new_width
+    resized_img = cv2.copyMakeBorder(resized_img, 0, 0,
+                                 0, pad_right, cv2.BORDER_CONSTANT, value=(0,0,0))
+    return resized_img
+
+
def save_checkpoint(state, checkpoint='checkpoint', filename='model.pth.tar'):
filepath = os.path.join(checkpoint, filename)
torch.save(state, filepath)
@@ -89,8 +132,13 @@ def loss_mean(self):
def loss_clear(self):
self.loss_items = []
-def create_process_obj(algorithm,pred):
+def create_process_obj(config,pred):
+ algorithm = config['base']['algorithm']
+
if(algorithm=='DB'):
+ if 'n_class' in config['base'].keys():
+ pred,pred_class = pred
+ return pred.cpu().numpy(),pred_class
return pred.cpu().numpy()
elif(algorithm=='SAST'):
pred['f_score'] = pred['f_score'].cpu().numpy()
@@ -102,10 +150,15 @@ def create_process_obj(algorithm,pred):
return pred
-def create_loss_bin(algorithm,use_distil=False,use_center=False):
+def create_loss_bin(config,use_distil=False,use_center=False):
+ algorithm = config['base']['algorithm']
bin_dict = {}
if(algorithm=='DB'):
- keys = ['loss_total','loss_l1', 'loss_bce', 'loss_thresh']
+ if 'n_class' in config['base'].keys():
+ keys = ['loss_total','loss_l1', 'loss_bce','loss_class', 'loss_thresh']
+ else:
+ keys = ['loss_total','loss_l1', 'loss_bce', 'loss_thresh']
+
elif(algorithm=='PSE'):
keys = ['loss_total','loss_kernel', 'loss_text']
elif(algorithm=='PAN'):
@@ -117,6 +170,8 @@ def create_loss_bin(algorithm,use_distil=False,use_center=False):
keys = ['loss_total','loss_ctc','loss_center']
else:
keys = ['loss_ctc']
+ elif (algorithm == 'FC'):
+ keys = ['loss_fc']
else:
assert 1==2,'only support algorithm DB,SAST,PSE,PAN,CRNN !!!'
@@ -144,14 +199,14 @@ def load_model(model,model_path):
if torch.cuda.is_available():
model_dict = torch.load(model_path)
else:
- model_dict = torch.load(model_path,map_location='cpu')
-
+ model_dict = torch.load(model_path,map_location='cpu')
if('state_dict' in model_dict.keys()):
model_dict = model_dict['state_dict']
try:
model.load_state_dict(model_dict)
except:
+
state = model.state_dict()
for key in state.keys():
state[key] = model_dict['module.' + key]
@@ -167,6 +222,21 @@ def merge_config(config,args):
return config
+def FreezeParameters(model,layer_name):
+ for name, parameter in model.named_parameters():
+ if layer_name in name:
+ parameter.requires_grad = False
+ return filter(lambda p: p.requires_grad, model.parameters())
+
+def ReleaseParameters(model,optimizer,layer_name):
+ for name, parameter in model.named_parameters():
+ parameter_dict = {}
+ if layer_name in name:
+ parameter.requires_grad = True
+ parameter_dict['params'] = parameter
+ optimizer.add_param_group(parameter_dict)
+ return optimizer
+
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
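
`FreezeParameters`/`ReleaseParameters` pair with the new `AdamDecay(config, parameters)` signature: freeze a named block, build the optimizer over what is still trainable, then add the block back as extra param groups later. A self-contained sketch (the model and config below are stand-ins, not this repo's real ones):

```python
import torch.nn as nn
from ptocr.optimizer import AdamDecay
from ptocr.utils.util_function import FreezeParameters, ReleaseParameters

model = nn.Sequential()
model.add_module('backbone', nn.Linear(8, 8))
model.add_module('head', nn.Linear(8, 2))
config = {'optimizer': {'base_lr': 1e-3, 'beta1': 0.9, 'beta2': 0.999,
                        'weight_decay': 0.0}}

params = FreezeParameters(model, 'backbone')   # backbone grads switched off
optimizer = AdamDecay(config, params)          # optimizer sees the head only
# ...train the head for a while, then hand the backbone back:
optimizer = ReleaseParameters(model, optimizer, 'backbone')
```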
diff --git a/script/__pycache__/__init__.cpython-36.pyc b/script/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..04478db
Binary files /dev/null and b/script/__pycache__/__init__.cpython-36.pyc differ
diff --git a/script/__pycache__/onnx_to_tensorrt.cpython-36.pyc b/script/__pycache__/onnx_to_tensorrt.cpython-36.pyc
new file mode 100644
index 0000000..d7121cb
Binary files /dev/null and b/script/__pycache__/onnx_to_tensorrt.cpython-36.pyc differ
diff --git a/script/create_lmdb.py b/script/create_lmdb.py
index 51833ee..64e13eb 100644
--- a/script/create_lmdb.py
+++ b/script/create_lmdb.py
@@ -5,6 +5,7 @@
import argparse
import shutil
import sys
+from tqdm import tqdm
def checkImageIsValid(imageBin):
if imageBin is None:
@@ -54,7 +55,9 @@ def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkV
env = lmdb.open(outputPath, map_size=1099511627776)
cache = {}
cnt = 1
+ bar = tqdm(total=nSamples)
for i in range(nSamples):
+ bar.update(1)
imagePath = imagePathList[i]
label = labelList[i]
@@ -78,50 +81,38 @@ def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkV
if cnt % 1000 == 0:
writeCache(env, cache)
cache = {}
- print('Written %d / %d' % (cnt, nSamples))
+# print('Written %d / %d' % (cnt, nSamples))
cnt += 1
+ bar.close()
nSamples = cnt-1
cache['num-samples'] = str(nSamples)
writeCache(env, cache)
env.close()
print('Created dataset with %d samples' % nSamples)
-def read_data_from_folder(folder_path):
- image_path_list = []
- label_list = []
- pics = os.listdir(folder_path)
- pics.sort(key = lambda i: len(i))
- for pic in pics:
- image_path_list.append(folder_path + '/' + pic)
- label_list.append(pic.split('_')[0])
- return image_path_list, label_list
def read_data_from_file(file_path):
image_path_list = []
label_list = []
- f = open(file_path)
- while True:
- line1 = f.readline()
- line2 = f.readline()
- if not line1 or not line2:
- break
- line1 = line1.replace('\r', '').replace('\n', '')
- line2 = line2.replace('\r', '').replace('\n', '')
- image_path_list.append(line1)
- label_list.append(line2)
+ with open(file_path,'r',encoding='utf-8') as fid:
+ lines = fid.readlines()
+ for line in lines:
+ line = line.replace('\r', '').replace('\n', '').split('\t')
+ image_path_list.append(line[0])
+ label_list.append(line[1])
return image_path_list, label_list
def show_demo(demo_number, image_path_list, label_list):
- print ('\nShow some demo to prevent creating wrong lmdb data')
print ('The first line is the path to image and the second line is the image label')
+ print ('###########################################################################')
for i in range(demo_number):
print ('image: %s\nlabel: %s\n' % (image_path_list[i], label_list[i]))
+ print ('###########################################################################')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--out', type = str, required = True, help = 'lmdb data output path')
- parser.add_argument('--folder', type = str, help = 'path to folder which contains the images')
parser.add_argument('--file', type = str, help = 'path to file which contains the image path and label')
args = parser.parse_args()
@@ -129,10 +120,7 @@ def show_demo(demo_number, image_path_list, label_list):
image_path_list, label_list = read_data_from_file(args.file)
createDataset(args.out, image_path_list, label_list)
show_demo(2, image_path_list, label_list)
- elif args.folder is not None:
- image_path_list, label_list = read_data_from_folder(args.folder)
- createDataset(args.out, image_path_list, label_list)
- show_demo(2, image_path_list, label_list)
+
else:
- print ('Please use --floder or --file to assign the input. Use -h to see more.')
+ print ('Please use --file to assign the input. Use -h to see more.')
sys.exit()
\ No newline at end of file
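
The rewritten `read_data_from_file` expects one sample per line in the form `<image_path>\t<label>`. A tiny sketch that writes such a list (hypothetical paths), after which the script is run as in the comment:

```python
# Then: python script/create_lmdb.py --out ./train_lmdb --file train_list.txt
samples = [('./imgs/word_1.jpg', 'hello'), ('./imgs/word_2.jpg', 'world')]
with open('train_list.txt', 'w', encoding='utf-8') as fid:
    for path, label in samples:
        fid.write(path + '\t' + label + '\n')
```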
diff --git a/script/create_lmdb_multiprocessing.py b/script/create_lmdb_multiprocessing.py
new file mode 100644
index 0000000..3ed9b1d
--- /dev/null
+++ b/script/create_lmdb_multiprocessing.py
@@ -0,0 +1,206 @@
+import os
+import lmdb
+import cv2
+import numpy as np
+import argparse
+import shutil
+import sys
+from tqdm import tqdm
+import time
+import six
+from PIL import Image
+from multiprocessing import Process
+
+def get_dict(char_list):
+ char_dict={}
+ for item in char_list:
+ if item in char_dict.keys():
+ char_dict[item]+=1
+ else:
+ char_dict[item]=1
+ return char_dict
+
+def checklmdb(args):
+ env = lmdb.open(
+ args.out,
+ max_readers=2,
+ readonly=True,
+ lock=False,
+ readahead=False,
+ meminit=False)
+
+ with env.begin(write=False) as txn:
+ nSamples = int(txn.get('num-samples'.encode('utf-8')))
+ print('Check lmdb ok!!!')
+        print('lmdb has {} samples'.format(nSamples))
+ print('Print 5 samples:')
+
+ for index in range(1,nSamples+1):
+ if index>5:
+ break
+ img_key = 'image-%09d' % index
+ imgbuf = txn.get(img_key.encode('utf-8'))
+ label_key = 'label-%09d' % index
+ label = txn.get(label_key.encode('utf-8')).decode()
+ print(index,label)
+
+
+def checkImageIsValid(imageBin):
+ if imageBin is None:
+ return False
+
+ try:
+        imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
+ img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
+ imgH, imgW = img.shape[0], img.shape[1]
+ except:
+ return False
+ else:
+ if imgH * imgW == 0:
+ return False
+
+ return True
+
+
+def writeCache(env, cache):
+ with env.begin(write=True) as txn:
+ for k, v in cache.items():
+ if type(k) == str:
+ k = k.encode()
+ if type(v) == str:
+ v = v.encode()
+ txn.put(k,v)
+
+
+def write(imagePathList,labelList,env,start,end):
+ cache = {}
+ cnt = 1
+ bar = tqdm(total=end-start)
+ checkValid=True
+ lexiconList=None
+ for i in range(end-start):
+ bar.update(1)
+ imagePath = imagePathList[start+i]
+ label = labelList[start+i]
+
+ if not os.path.exists(imagePath):
+ print('%s does not exist' % imagePath)
+ continue
+ with open(imagePath, 'rb') as f:
+ imageBin = f.read()
+ if checkValid:
+ if not checkImageIsValid(imageBin):
+ print('%s is not a valid image' % imagePath)
+ continue
+
+ imageKey = 'image-%09d' % (start+cnt)
+ labelKey = 'label-%09d' % (start+cnt)
+
+ cache[imageKey] = imageBin
+ cache[labelKey] = label
+ if lexiconList:
+ lexiconKey = 'lexicon-%09d' % cnt
+ cache[lexiconKey] = ' '.join(lexiconList[i])
+ if (start+cnt) % 1000 == 0:
+ writeCache(env, cache)
+ cache = {}
+ cnt += 1
+ bar.close()
+ writeCache(env, cache)
+
+def createDataset(outputPath, imagePathList, labelList, num=1,lexiconList=None, checkValid=True):
+ """
+ Create LMDB dataset for CRNN training.
+ ARGS:
+ outputPath : LMDB output path
+ imagePathList : list of image path
+ labelList : list of corresponding groundtruth texts
+ lexiconList : (optional) list of lexicon lists
+ checkValid : if true, check the validity of every image
+ """
+    # If the lmdb directory already exists, remove it; otherwise new data would be appended to it.
+ if os.path.exists(outputPath):
+ shutil.rmtree(outputPath)
+ os.makedirs(outputPath)
+ else:
+ os.makedirs(outputPath)
+
+ assert (len(imagePathList) == len(labelList))
+ nSamples = len(imagePathList)
+ env = lmdb.open(outputPath, map_size=1099511627776)
+ cache = {}
+ index = []
+
+ if nSamples%num==0:
+ step = nSamples//num
+ for i in range(num):
+ index.append([i*step,(i+1)*step])
+ else:
+ step = nSamples//num
+ for i in range(num):
+ index.append([i*step,(i+1)*step])
+ index.append([num*step,nSamples])
+
+ p_list = []
+ if nSamples%num==0:
+ for i in range(num):
+ p = Process(target=write,args=(imagePathList,labelList,env,index[i][0],index[i][1]))
+ p_list.append(p)
+ p.start()
+ else:
+ for i in range(num+1):
+ p = Process(target=write,args=(imagePathList,labelList,env,index[i][0],index[i][1]))
+ p_list.append(p)
+ p.start()
+ for p in p_list:
+ p.join()
+ cache['num-samples'] = str(nSamples)
+ writeCache(env, cache)
+ env.close()
+ print('Created dataset with %d samples' % nSamples)
+
+
+def read_data_from_file(file_path):
+ image_path_list = []
+ label_list = []
+ with open(file_path,'r',encoding='utf-8') as fid:
+ lines = fid.readlines()
+ for line in lines:
+ line = line.replace('\r', '').replace('\n', '').split('\t')
+ image_path_list.append(line[0])
+ label_list.append(line[1])
+
+ return image_path_list, label_list
+
+def show_demo(demo_number, image_path_list, label_list):
+ print ('The first line is the path to image and the second line is the image label')
+ print ('###########################################################################')
+ for i in range(demo_number):
+ print ('image: %s\nlabel: %s\n' % (image_path_list[i], label_list[i]))
+ print ('###########################################################################')
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--out', type = str, required = True, help = 'lmdb data output path')
+ parser.add_argument('--file', type = str,required = True, help = 'path to file which contains the image path and label')
+ parser.add_argument('--num_process', type = int, required = True, help = 'num_process to do')
+ args = parser.parse_args()
+
+ if args.file is not None:
+ image_path_list, label_list = read_data_from_file(args.file)
+ show_demo(2, image_path_list, label_list)
+ s_time = time.time()
+ createDataset(args.out, image_path_list, label_list,num=args.num_process)
+ print('cost_time:'+str(time.time()-s_time)+'s')
+ else:
+ print ('Please use --file to assign the input. Use -h to see more.')
+ sys.exit()
+ print('lmdb generate ok!!!!')
+ checklmdb(args)
+
+
+
+
+
+
+
\ No newline at end of file
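
The multiprocessing variant splits the sample range into `num_process` even chunks, plus a remainder chunk when the division is not exact, and hands one chunk to each writer process. The chunking logic, isolated for clarity:

```python
# Mirrors the index computation in createDataset above.
def chunk_ranges(n_samples, num):
    step = n_samples // num
    index = [[i * step, (i + 1) * step] for i in range(num)]
    if n_samples % num != 0:
        index.append([num * step, n_samples])
    return index

print(chunk_ranges(10, 3))  # [[0, 3], [3, 6], [6, 9], [9, 10]]
```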
diff --git a/script/get_train_list.py b/script/get_train_list.py
index 6e5872d..0b481b9 100644
--- a/script/get_train_list.py
+++ b/script/get_train_list.py
@@ -26,4 +26,5 @@ def gen_train_file(args):
parser.add_argument('--label_path', nargs='?', type=str, default=None)
parser.add_argument('--img_path', nargs='?', type=str, default=None)
parser.add_argument('--save_path', nargs='?', type=str, default=None)
- args = parser.parse_args()
\ No newline at end of file
+ args = parser.parse_args()
+ gen_train_file(args)
\ No newline at end of file
diff --git a/script/warp_polar.py b/script/warp_polar.py
new file mode 100644
index 0000000..00399a6
--- /dev/null
+++ b/script/warp_polar.py
@@ -0,0 +1,67 @@
+#-*- coding:utf-8 _*-
+"""
+@author:fxw
+@file: tt.py
+@time: 2020/12/25
+"""
+import cv2
+import numpy as np
+import sys
+
+# Polar-coordinate unwrapping of an image. center: transform center; r: a (min, max)
+# radius pair; theta: angular range in degrees; rstep: radial step; thetastep: angular step.
+def polar(image,center,r,theta=(70,360+70),rstep=0.8,thetastep=360.0/(360*2)):
+    cx,cy=center            # was read from globals, which broke reuse of this function
+    h,w=image.shape[:2]
+    # min/max radius
+    minr,maxr=r
+    # angular range
+    mintheta,maxtheta=theta
+    # output size; O starts as a gray (125) canvas of the image's dtype
+    H=int((maxr-minr)/rstep)+1
+    W=int((maxtheta-mintheta)/thetastep)+1
+    O=125*np.ones((H,W,3),image.dtype)
+    # build (H, W) grids of radii and angles via tile/transpose
+    r=np.linspace(minr,maxr,H)
+    r=np.tile(r,(W,1))
+    r=np.transpose(r)
+    theta=np.linspace(mintheta,maxtheta,W)
+    theta=np.tile(theta,(H,1))
+    x,y=cv2.polarToCart(r,theta,angleInDegrees=True)
+    # nearest-neighbour sampling back into the source image
+    for i in range(H):
+        for j in range(W):
+            px=int(round(x[i][j])+cx)
+            py=int(round(y[i][j])+cy)
+            if((px>=0 and px<=w-1) and (py>=0 and py<=h-1)):
+                O[i][j][0]=image[py][px][0]
+                O[i][j][1]=image[py][px][1]
+                O[i][j][2]=image[py][px][2]
+
+    return O
+
+import time
+if __name__=="__main__":
+ img = cv2.imread(r"C:\Users\fangxuwei\Desktop\111.jpg")
+    # report the input image size
+ h, w = img.shape[:2]
+ print("h:%s w:%s"%(h,w))
+    # center of the polar transform
+ # cx, cy = h//2, w//2
+ cx, cy = 204, 201
+ # cx, cy = 131, 123
+    # mark the chosen center with a small circle
+ cv2.circle(img, (int(cx), int(cy)), 10, (255, 0, 0, 0), 3)
+ s = time.time()
+ L = polar(img, (cx, cy), (h//3, w//2))
+    # flip vertically so the outer radius ends up on top
+ L = cv2.flip(L, 0)
+ print(time.time()-s)
+
+    # display the input and the unwrapped strip
+ cv2.imshow('img', img)
+ cv2.imshow('O', L)
+ cv2.waitKey(0)
+
+
+
+
+
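
With the fix above, `polar` takes its center explicitly and derives the source size from the image, so it can be imported and reused. A usage sketch on a synthetic ring (e.g. to flatten text printed along a circular seal):

```python
import cv2
import numpy as np
from script.warp_polar import polar

img = np.zeros((400, 400, 3), dtype=np.uint8)
cv2.circle(img, (200, 200), 120, (255, 255, 255), 2)  # a synthetic ring
strip = polar(img, (200, 200), (100, 140))
strip = cv2.flip(strip, 0)                            # outer edge on top
print(strip.shape)
```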
diff --git a/tools/__pycache__/MarginLoss.cpython-36.pyc b/tools/__pycache__/MarginLoss.cpython-36.pyc
new file mode 100644
index 0000000..34f1559
Binary files /dev/null and b/tools/__pycache__/MarginLoss.cpython-36.pyc differ
diff --git a/tools/__pycache__/__init__.cpython-36.pyc b/tools/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..f3c0dad
Binary files /dev/null and b/tools/__pycache__/__init__.cpython-36.pyc differ
diff --git a/tools/cal_rescall/__pycache__/__init__.cpython-36.pyc b/tools/cal_rescall/__pycache__/__init__.cpython-36.pyc
index a36e979..22329de 100644
Binary files a/tools/cal_rescall/__pycache__/__init__.cpython-36.pyc and b/tools/cal_rescall/__pycache__/__init__.cpython-36.pyc differ
diff --git a/tools/cal_rescall/__pycache__/cal_det.cpython-36.pyc b/tools/cal_rescall/__pycache__/cal_det.cpython-36.pyc
index 9b788ab..1ba789f 100644
Binary files a/tools/cal_rescall/__pycache__/cal_det.cpython-36.pyc and b/tools/cal_rescall/__pycache__/cal_det.cpython-36.pyc differ
diff --git a/tools/cal_rescall/__pycache__/cal_iou.cpython-36.pyc b/tools/cal_rescall/__pycache__/cal_iou.cpython-36.pyc
index bca5b29..152a5c3 100644
Binary files a/tools/cal_rescall/__pycache__/cal_iou.cpython-36.pyc and b/tools/cal_rescall/__pycache__/cal_iou.cpython-36.pyc differ
diff --git a/tools/cal_rescall/__pycache__/rrc_evaluation_funcs.cpython-36.pyc b/tools/cal_rescall/__pycache__/rrc_evaluation_funcs.cpython-36.pyc
index 66f438a..e52d2ad 100644
Binary files a/tools/cal_rescall/__pycache__/rrc_evaluation_funcs.cpython-36.pyc and b/tools/cal_rescall/__pycache__/rrc_evaluation_funcs.cpython-36.pyc differ
diff --git a/tools/cal_rescall/__pycache__/script.cpython-36.pyc b/tools/cal_rescall/__pycache__/script.cpython-36.pyc
index 6eebef9..8a8f59f 100644
Binary files a/tools/cal_rescall/__pycache__/script.cpython-36.pyc and b/tools/cal_rescall/__pycache__/script.cpython-36.pyc differ
diff --git a/tools/det_infer.py b/tools/det_infer.py
index bc9348e..94986d4 100644
--- a/tools/det_infer.py
+++ b/tools/det_infer.py
@@ -16,6 +16,7 @@
import numpy as np
from tqdm import tqdm
import onnxruntime
+import torch.nn.functional as F
import torchvision.transforms as transforms
from ptocr.utils.util_function import create_module,resize_image_batch
from ptocr.utils.util_function import create_process_obj,create_dir,load_model
@@ -115,19 +116,28 @@ def infer_img(self,ori_imgs):
out = output[0].reshape(int(args.batch_size),7,test_size,test_size)
out = torch.Tensor(out)
-
-# import pdb
-# pdb.set_trace()
-
+
if isinstance(out,dict):
img_num = out['f_score'].shape[0]
+ elif isinstance(out,tuple):
+ img_num = out[0].shape[0]
else:
img_num = out.shape[0]
if(self.config['base']['algorithm']=='SAST'):
- scales = [(scale[0],scale[1],ori_imgs[i].shape[0],ori_imgs[i].shape[1]) for scale in scales]
+ scales = [(1.0/scales[i][1],1.0/scales[i][0],ori_imgs[i].shape[0],ori_imgs[i].shape[1]) for i in range(len(scales))]
- out = create_process_obj(self.config['base']['algorithm'],out)
+ out = create_process_obj(self.config,out)
+
+ if 'n_class' in self.config['base'].keys():
+ out,out_class = out
+            b,_,h,w = out_class.shape
+            out_class = out_class.permute(0,2,3,1).reshape(-1,self.config['base']['n_class'])
+            out_class = F.softmax(out_class,-1)
+            out_class = out_class.max(1)[1].reshape(b,h,w).unsqueeze(1)
+ bbox_batch, score_batch,class_batch = self.img_process(out,out_class.cpu().numpy(), scales)
+ return bbox_batch,score_batch,class_batch
+
bbox_batch, score_batch = self.img_process(out, scales)
return bbox_batch,score_batch
@@ -145,6 +155,28 @@ def InferOneImg(bin,img,image_name,save_path):
fid_res.write(bbox_str)
cv2.imwrite(os.path.join(save_path, 'result_img', image_name[i] + '.jpg'), img_show)
+color_rgb = [(0,0,255),(0,255,0),(255,0,0),(255,255,0),(255,0,255),(0,255,255),(156,102,31),(255,192,203),(160,32,240),(115,74,18)]
+
+def InferOneImgMul(bin,img,image_name,save_path,n_class):
+ bbox_batch, score_batch,class_batch = bin.infer_img(img)
+    colors = {}
+    for i in range(n_class):
+        colors[str(i)] = color_rgb[i]
+ for i in range(len(bbox_batch)):
+ img_show = img[i].copy()
+ with open(os.path.join(save_path, 'result_txt', 'res_' + image_name[i] + '.txt'), 'w+', encoding='utf-8') as fid_res:
+ bboxes = bbox_batch[i]
+ classes = class_batch[i]
+ tag = 0
+ for bbox in bboxes:
+            bbox = bbox.reshape(-1, 2).astype(np.int32)
+ img_show = cv2.drawContours(img_show, [bbox], -1, colors[str(int(classes[tag]))], 2)
+ bbox_str = [str(x) for x in bbox.reshape(-1)]
+ bbox_str = ','.join(bbox_str) + '\n'
+ fid_res.write(bbox_str)
+ tag+=1
+ cv2.imwrite(os.path.join(save_path, 'result_img', image_name[i] + '.jpg'), img_show)
+
def InferImage(config):
path = config['infer']['path']
save_path = config['infer']['save_path']
@@ -159,7 +191,10 @@ def InferImage(config):
bar = tqdm(total=len(batch_imgs))
for i in range(len(batch_imgs)):
bar.update(1)
- InferOneImg(test_bin, batch_imgs[i],batch_img_names[i], save_path)
+ if 'n_class' in config['base'].keys():
+ InferOneImgMul(test_bin, batch_imgs[i],batch_img_names[i], save_path,config['base']['n_class'])
+ else:
+ InferOneImg(test_bin, batch_imgs[i],batch_img_names[i], save_path)
bar.close()
else:
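
The multi-class branch in `infer_img` reduces the raw class logits to a per-pixel class-id map before post-processing. The same reduction in isolation:

```python
# Softmax over the n_class channel, then per-pixel argmax back to (B, 1, H, W).
import torch
import torch.nn.functional as F

n_class = 3
out_class = torch.randn(2, n_class, 8, 8)          # raw per-pixel class logits
b, _, h, w = out_class.shape
flat = out_class.permute(0, 2, 3, 1).reshape(-1, n_class)
class_map = F.softmax(flat, -1).max(1)[1].reshape(b, h, w).unsqueeze(1)
print(class_map.shape)                             # torch.Size([2, 1, 8, 8])
```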
diff --git a/tools/det_sast.py b/tools/det_sast.py
index 82087a4..7bd8eb1 100644
--- a/tools/det_sast.py
+++ b/tools/det_sast.py
@@ -192,7 +192,7 @@ def TrainValProgram(config):
log_write.set_names(title)
for epoch in range(start_epoch,config['base']['n_epoch']):
- train_dataset.gen_train_img()
+# train_dataset.gen_train_img()
model.train()
optimizer_decay(config, optimizer, epoch)
loss_write = ModelTrain(train_data_loader, model, criterion, optimizer, loss_bin, config, epoch)
@@ -231,8 +231,8 @@ def TrainValProgram(config):
if __name__ == "__main__":
- stream = open('./config/det_SAST_resnet50_ori_dataload.yaml', 'r', encoding='utf-8')
-# stream = open('./config/det_SAST_resnet50_3*3_ori_dataload.yaml', 'r', encoding='utf-8')
+# stream = open('./config/det_SAST_resnet50_ori_dataload.yaml', 'r', encoding='utf-8')
+ stream = open('./config/det_SAST_resnet50_3_3_ori_dataload.yaml', 'r', encoding='utf-8')
config = yaml.load(stream,Loader=yaml.FullLoader)
TrainValProgram(config)
\ No newline at end of file
diff --git a/tools/det_train.py b/tools/det_train.py
index d22ad67..d8ee729 100644
--- a/tools/det_train.py
+++ b/tools/det_train.py
@@ -165,7 +165,7 @@ def TrainValProgram(args):
criterion = create_module(config['architectures']['loss_function'])(config)
train_dataset = create_module(config['trainload']['function'])(config)
test_dataset = create_module(config['testload']['function'])(config)
- optimizer = create_module(config['optimizer']['function'])(config, model)
+ optimizer = create_module(config['optimizer']['function'])(config, model.parameters())
optimizer_decay = create_module(config['optimizer_decay']['function'])
img_process = create_module(config['postprocess']['function'])(config)
@@ -195,7 +195,7 @@ def TrainValProgram(args):
use_distil = False
if args.t_config is not None:
use_distil = True
- loss_bin = create_loss_bin(config['base']['algorithm'],use_distil)
+ loss_bin = create_loss_bin(config,use_distil)
if torch.cuda.is_available():
if (len(config['base']['gpu_id'].split(',')) > 1):
diff --git a/tools/rec_infer.py b/tools/rec_infer.py
index c5cecac..31a505e 100644
--- a/tools/rec_infer.py
+++ b/tools/rec_infer.py
@@ -5,6 +5,7 @@
@time: 2020/08/20
"""
import os
+os.environ["CUDA_VISIBLE_DEVICES"] = '2'
import sys
sys.path.append('./')
import cv2
@@ -14,8 +15,24 @@
import numpy as np
from tqdm import tqdm
import torchvision.transforms as transforms
-from ptocr.utils.util_function import create_module,resize_image
+from ptocr.utils.util_function import create_module,resize_image,resize_image_crnn
from ptocr.utils.util_function import create_process_obj,create_dir,load_model
+from get_acc_english import get_test_acc
+
+def is_chinese(check_str):
+    for ch in check_str:
+        if u'\u4e00' <= ch <= u'\u9fff':
+            return True
+    return False
+
+def is_number_letter(check_str):
+    if(check_str in '0123456789' or check_str in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'):
+ return True
+ else:
+ return False
+
+
class TestProgram():
@@ -26,6 +43,7 @@ def __init__(self,config):
config['base']['classes'] = len(self.converter.alphabet)
model = create_module(config['architectures']['model_function'])(config)
model = load_model(model,config['infer']['model_path'])
+
if torch.cuda.is_available():
model = model.cuda()
self.model = model
@@ -33,39 +51,96 @@ def __init__(self,config):
self.model.eval()
def infer_img(self,ori_img):
- img = resize_image(ori_img,self.congig['base']['algorithm'],32)
+
+# img = resize_image(ori_img,self.congig['base']['algorithm'],32,4)
+ img = resize_image_crnn(ori_img)
+# cv2.imwrite('result.jpg',img)
+ if(img.shape[0]!=self.congig['base']['img_shape'][0]):
+ return ''
+
img = Image.fromarray(img).convert('RGB')
if(self.congig['base']['is_gray']):
img = img.convert('L')
+# img = img.resize((100, 32),Image.ANTIALIAS)
+ image_for_show = np.array(img.convert('RGB')).copy()
img = transforms.ToTensor()(img)
img.sub_(0.5).div_(0.5)
img = img.unsqueeze(0)
+
if torch.cuda.is_available():
img = img.cuda()
with torch.no_grad():
- preds = self.model(img)
+ preds,feau= self.model(img)
+ preds = preds.permute(1, 0, 2)
+
+ time_step = preds.shape[0]
+ image_width = img.shape[3]
+ step_ = image_width//time_step
+
+        ### dump top-k similar-character candidates (kept commented for debugging)
+# k_num = 3
+# p,idx = torch.softmax(preds,-1).squeeze().topk(k_num,-1)
+# p = p[idx[:,0]>0].cpu().numpy()
+# for i in range(k_num):
+# index = idx[:,i][idx[:,0]>0]
+# result = np.array(list(self.converter.alphabet))[index.cpu().numpy()-1].tolist()
+# if(i==0):
+#                 print('recognized:',result,p[:,i])
+#             else:
+#                 print('similar chars:',result,p[:,i])
+
preds_size = torch.IntTensor([preds.size(0)])
_, preds = preds.max(2)
preds = preds.squeeze(1)
preds = preds.contiguous().view(-1)
+
sim_preds = self.converter.decode(preds.data, preds_size.data, raw=False)
+# raw_pred = self.converter.decode(preds.data, preds_size.data, raw=True)
+# start_step = 0
+# word_po = []
+# for i in range(len(raw_pred)-1):
+# if(raw_pred[i]!='-' and raw_pred[i]!=raw_pred[i+1]):
+# # image_for_show = cv2.rectangle(image_for_show,(start_step+step_,0),(start_step+step_, img.shape[2]-1),(0,0,255))
+# start_step+=step_
+# word_po.append(start_step+step_)
+# else:
+# start_step+=step_
+
+# # word_po = np.array(word_po)
+# # word_po_flag = np.zeros(len(word_po))
+# # word_po_flag[1:] = word_po[:-1]
+# # word_len = sorted(word_po-word_po_flag)[len(word_po)//2]
+# word_s_e = []
+# for i in range(len(word_po)):
+# if(is_number_letter(sim_preds[i])):
+# word_len = self.congig['base']['img_shape'][0]//4
+# else:
+# word_len = self.congig['base']['img_shape'][0]//2
+# word_s_e.append([word_po[i]-word_len,word_po[i]+word_len])
+# # image_for_show = cv2.rectangle(image_for_show,(int(word_po[i]-word_len),1),(int(word_po[i]+word_len), img.shape[2]-1),(255,0,255))
+# image_for_show = cv2.line(image_for_show,(int(word_po[i]),1),(int(word_po[i]), img.shape[2]-1),(255,0,255))
+# cv2.imwrite('result.jpg',image_for_show)
+
return sim_preds
def InferImage(config):
path = config['infer']['path']
save_path = config['infer']['save_path']
test_bin = TestProgram(config)
+ fid = open('re_result.txt','w+',encoding='utf-8')
if os.path.isdir(path):
files = os.listdir(path)
bar = tqdm(total=len(files))
for file in files:
+ print(file)
bar.update(1)
image_name = file.split('.')[0]
img_path = os.path.join(path,file)
img = cv2.imread(img_path)
rec_char = test_bin.infer_img(img)
+ fid.write(file+'\t'+rec_char+'\n')
print(rec_char)
bar.close()
@@ -74,9 +149,51 @@ def InferImage(config):
img = cv2.imread(path)
rec_char = test_bin.infer_img(img)
print(rec_char)
+
+def InferImageAcc(config):
+ root_path = config['infer']['path']
+ save_path = config['infer']['save_path']
+ test_bin = TestProgram(config)
+ acc_fid = open('acc.txt','w+',encoding='utf-8')
+ acc_all = 0
+ for _dir in os.listdir(root_path):
+ path = os.path.join(root_path,_dir,'image')
+ gt_file = os.path.join(root_path,_dir,'val.txt')
+# print(gt_file)
+ fid = open(_dir+'.txt','w+',encoding='utf-8')
+ if os.path.isdir(path):
+ files = os.listdir(path)
+ bar = tqdm(total=len(files))
+ for file in files:
+# print(file)
+ bar.update(1)
+ image_name = file.split('.')[0]
+ img_path = os.path.join(path,file)
+ img = cv2.imread(img_path)
+ rec_char = test_bin.infer_img(img)
+ fid.write(file+'\t'+rec_char+'\n')
+# print(rec_char)
+ bar.close()
+ else:
+ image_name = path.split('/')[-1].split('.')[0]
+ img = cv2.imread(path)
+ rec_char = test_bin.infer_img(img)
+ print(rec_char)
+ fid.close()
+ acc = get_test_acc(_dir+'.txt',gt_file)
+ acc_all+=acc
+ acc_fid.write(_dir+':'+'\t'+str(acc)+'\n')
+    print('mean acc:',acc_all/len(os.listdir(root_path)))
+ acc_fid.close()
+
if __name__ == "__main__":
- stream = open('./config/rec_CRNN_ori.yaml', 'r', encoding='utf-8')
+
+ stream = open('./config/rec_CRNN_mobilev3_small_english_all.yaml', 'r', encoding='utf-8')
config = yaml.load(stream,Loader=yaml.FullLoader)
- InferImage(config)
\ No newline at end of file
+ InferImageAcc(config)
+
+# stream = open('./config/rec_CRNN_resnet_english.yaml', 'r', encoding='utf-8')
+# config = yaml.load(stream,Loader=yaml.FullLoader)
+# InferImage(config)
\ No newline at end of file
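
Quick sanity checks for the fixed helpers above (assumes `is_chinese` and `is_number_letter` from `tools/rec_infer.py` are in scope):

```python
print(is_chinese('OCR识别'))   # True  -- contains CJK characters
print(is_chinese('hello'))     # False
print(is_number_letter('q'))   # True  -- 'q' was missing from the old string
print(is_number_letter('中'))  # False
```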
diff --git a/tools/rec_infer_bk1.py b/tools/rec_infer_bk1.py
new file mode 100644
index 0000000..b1cd2a7
--- /dev/null
+++ b/tools/rec_infer_bk1.py
@@ -0,0 +1,159 @@
+#-*- coding:utf-8 _*-
+"""
+@author:fxw
+@file: det_infer.py
+@time: 2020/08/20
+"""
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = '2'
+import sys
+sys.path.append('./')
+import cv2
+import torch
+import yaml
+from PIL import Image
+import numpy as np
+from tqdm import tqdm
+import torchvision.transforms as transforms
+from ptocr.utils.util_function import create_module,resize_image,resize_image_crnn
+from ptocr.utils.util_function import create_process_obj,create_dir,load_model
+from get_acc_english import get_test_acc
+
+def is_chinese(check_str):
+    for ch in check_str:
+        if u'\u4e00' <= ch <= u'\u9fff':
+            return True
+    return False
+
+def is_number_letter(check_str):
+    if(check_str in '0123456789' or check_str in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'):
+ return True
+ else:
+ return False
+
+
+
+
+class TestProgram():
+ def __init__(self,config):
+ super(TestProgram,self).__init__()
+
+ self.converter = create_module(config['label_transform']['function'])(config)
+ model = create_module(config['architectures']['model_function'])(config)
+ model = load_model(model,config['infer']['model_path'])
+
+ if torch.cuda.is_available():
+ model = model.cuda()
+ self.model = model
+        self.config = config
+ self.model.eval()
+
+ def infer_img(self,ori_img):
+
+#         img = resize_image(ori_img,self.config['base']['algorithm'],32,4)
+ img = resize_image_crnn(ori_img)
+# cv2.imwrite('result.jpg',img)
+        if(img.shape[0]!=self.config['base']['img_shape'][0]):
+ return ''
+
+ img = Image.fromarray(img).convert('RGB')
+        if(self.config['base']['is_gray']):
+ img = img.convert('L')
+# img = img.resize((100, 32),Image.ANTIALIAS)
+ image_for_show = np.array(img.convert('RGB')).copy()
+ img = transforms.ToTensor()(img)
+ img.sub_(0.5).div_(0.5)
+ img = img.unsqueeze(0)
+
+ if torch.cuda.is_available():
+ img = img.cuda()
+
+ with torch.no_grad():
+ preds,feau= self.model(img)
+
+ time_step = preds.shape[0]
+ image_width = img.shape[3]
+ step_ = image_width//time_step
+
+
+ _, preds = preds.max(2)
+
+
+ sim_preds = self.converter.decode(preds.data)
+ print(preds)
+ print(sim_preds)
+ return sim_preds[0]
+
+def InferImage(config):
+ path = config['infer']['path']
+ save_path = config['infer']['save_path']
+ test_bin = TestProgram(config)
+ fid = open('re_result.txt','w+',encoding='utf-8')
+ if os.path.isdir(path):
+ files = os.listdir(path)
+ bar = tqdm(total=len(files))
+ for file in files:
+ print(file)
+ bar.update(1)
+ image_name = file.split('.')[0]
+ img_path = os.path.join(path,file)
+ img = cv2.imread(img_path)
+ rec_char = test_bin.infer_img(img)
+ fid.write(file+'\t'+rec_char+'\n')
+ print(rec_char)
+ bar.close()
+
+ else:
+ image_name = path.split('/')[-1].split('.')[0]
+ img = cv2.imread(path)
+ rec_char = test_bin.infer_img(img)
+ print(rec_char)
+
+def InferImageAcc(config):
+ root_path = config['infer']['path']
+ save_path = config['infer']['save_path']
+ test_bin = TestProgram(config)
+ acc_fid = open('acc.txt','w+',encoding='utf-8')
+ acc_all = 0
+ for _dir in os.listdir(root_path):
+ path = os.path.join(root_path,_dir,'image')
+ gt_file = os.path.join(root_path,_dir,'val.txt')
+# print(gt_file)
+ fid = open(_dir+'.txt','w+',encoding='utf-8')
+ if os.path.isdir(path):
+ files = os.listdir(path)
+ bar = tqdm(total=len(files))
+ for file in files:
+# print(file)
+ bar.update(1)
+ image_name = file.split('.')[0]
+ img_path = os.path.join(path,file)
+ img = cv2.imread(img_path)
+ rec_char = test_bin.infer_img(img)
+ fid.write(file+'\t'+rec_char+'\n')
+# print(rec_char)
+ bar.close()
+
+ else:
+ image_name = path.split('/')[-1].split('.')[0]
+ img = cv2.imread(path)
+ rec_char = test_bin.infer_img(img)
+ print(rec_char)
+ fid.close()
+ acc = get_test_acc(_dir+'.txt',gt_file)
+ acc_all+=acc
+ acc_fid.write(_dir+':'+'\t'+str(acc)+'\n')
+    print('mean acc:',acc_all/len(os.listdir(root_path)))
+ acc_fid.close()
+
+
+if __name__ == "__main__":
+
+ stream = open('./config/rec_FC_resnet_english_all.yaml', 'r', encoding='utf-8')
+ config = yaml.load(stream,Loader=yaml.FullLoader)
+ InferImageAcc(config)
+
+# stream = open('./config/rec_CRNN_resnet_english.yaml', 'r', encoding='utf-8')
+# config = yaml.load(stream,Loader=yaml.FullLoader)
+# InferImage(config)
\ No newline at end of file
diff --git a/tools/rec_train.py b/tools/rec_train.py
index 73c7c41..b11ea09 100644
--- a/tools/rec_train.py
+++ b/tools/rec_train.py
@@ -1,9 +1,3 @@
-# -*- coding:utf-8 _*-
-"""
-@author:fxw
-@file: det_train.py
-@time: 2020/08/07
-"""
import sys
sys.path.append('./')
import cv2
@@ -21,173 +15,114 @@
from ptocr.utils.util_function import create_module, create_loss_bin, \
set_seed,save_checkpoint,create_dir
from ptocr.utils.metrics import runningScore
-from ptocr.utils.logger import Logger
-from ptocr.utils.util_function import create_process_obj,merge_config,AverageMeter
+from ptocr.utils.logger import Logger,TrainLog
+from ptocr.utils.util_function import create_process_obj,merge_config,AverageMeter,restore_training
from ptocr.dataloader.RecLoad.CRNNProcess import alignCollate
import copy
-GLOBAL_WORKER_ID = None
+### fix random seeds for reproducibility
GLOBAL_SEED = 2020
-
-
torch.manual_seed(GLOBAL_SEED)
torch.cuda.manual_seed(GLOBAL_SEED)
torch.cuda.manual_seed_all(GLOBAL_SEED)
np.random.seed(GLOBAL_SEED)
random.seed(GLOBAL_SEED)
-
-
-def worker_init_fn(worker_id):
- global GLOBAL_WORKER_ID
- GLOBAL_WORKER_ID = worker_id
- set_seed(GLOBAL_SEED + worker_id)
-
-# def backward_hook(self,grad_input, grad_output):
-# for g in grad_input:
-# g[g != g] = 0 # replace all nan/inf in gradients to zero
-
-
-def ModelTrain(train_data_loader,LabelConverter,model,center_model, criterion, optimizer,center_criterion,optimizer_center,center_flag,loss_bin, config, epoch):
- batch_time = AverageMeter()
- end = time.time()
-
- for batch_idx, data in enumerate(train_data_loader):
-# model.register_backward_hook(backward_hook)
- if(data is None):
- continue
- imgs,labels = data
+def ModelTrain(train_data_loader,LabelConverter, model,criterion,optimizer, train_log,loss_dict, config, epoch):
+    for batch_idx, (imgs, labels) in enumerate(train_data_loader):
pre_batch = {}
gt_batch = {}
if torch.cuda.is_available():
imgs = imgs.cuda()
- preds = model(imgs)
-
+ preds,feau = model(imgs)
+ preds = preds.permute(1, 0, 2)
+
+ #########
labels,labels_len = LabelConverter.encode(labels,preds.size(0))
preds_size = Variable(torch.IntTensor([preds.size(0)] * config['trainload']['batch_size']))
pre_batch['preds'],pre_batch['preds_size'] = preds,preds_size
gt_batch['labels'],gt_batch['labels_len'] = labels,labels_len
+ #########
+
+ if config['loss']['use_ctc_weight']:
+ len_index = torch.softmax(preds,-1).max(2)[1].transpose(0,1)>0
+ len_flag = torch.cat([labels_len.cuda().long().unsqueeze(0),len_index.sum(1).unsqueeze(0)],0)
+ ctc_loss_weight = len_flag.max(0)[0].float()/len_flag.min(0)[0].float()
+            ctc_loss_weight[torch.isinf(ctc_loss_weight)] = 2.0
+ gt_batch['ctc_loss_weight'] = ctc_loss_weight
+
+ loss = criterion(pre_batch, gt_batch).cuda()
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
- ctc_loss = criterion(pre_batch, gt_batch).cuda()
metrics = {}
- metrics['loss_total'] = 0.0
- metrics['loss_center'] = 0.0
- if center_criterion is not None and center_flag is True:
- center_model.eval()
- #####
- feautures = preds.clone()
- with torch.no_grad():
- center_preds = center_model(imgs)
-
- center_preds = torch.softmax(center_preds,-1)
- confs, center_preds = center_preds.max(2)
- center_preds = center_preds.squeeze(1).transpose(1, 0).contiguous()
- confs = confs.transpose(1, 0).contiguous()
-
-# confs = []
-# for i in range(center_preds.shape[0]):
-# conf = []
-# for j in range(len(center_preds[i])):
-# conf.append(probs[i,j,center_preds[i][j]])
-# confs.append(conf)
-# confs = torch.Tensor(confs).cuda()
-
- b,t = center_preds.shape
- feautures = feautures.transpose(1, 0).contiguous()
-
- confs = confs.view(-1)
- center_preds = center_preds.view(-1)
- feautures = feautures.view(b*t,-1)
-
-
- index = (center_preds>0) & (confs>config['loss']['label_score'])
- center_preds = center_preds[index]
- feautures = feautures[index]
-
- center_loss = center_criterion(feautures,center_preds)*config['loss']['weight_center']
-
- loss = ctc_loss + center_loss
-
-
- metrics['loss_total'] = loss.item()
- metrics['loss_ctc'] = ctc_loss.item()
- metrics['loss_center'] = center_loss.item()
-
- #####
- optimizer_center.zero_grad()
- optimizer.zero_grad()
-
- loss.backward()
- optimizer.step()
-
- for param in center_criterion.parameters():
- param.grad.data *= (1. / config['loss']['weight_center'])
- optimizer_center.step()
- else:
- loss = ctc_loss
- metrics['loss_ctc'] = ctc_loss.item()
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
+ metrics['ctc_loss'] = loss.item()
+
+ for key in loss_dict.keys():
+ loss_dict[key].update(metrics[key])
- for key in loss_bin.keys():
- loss_bin[key].loss_add(metrics[key])
- batch_time.update(time.time() - end)
- end = time.time()
if (batch_idx % config['base']['show_step'] == 0):
log = '({}/{}/{}/{}) | ' \
.format(epoch, config['base']['n_epoch'], batch_idx, len(train_data_loader))
- bin_keys = list(loss_bin.keys())
-
- for i in range(len(bin_keys)):
- log += bin_keys[i] + ':{:.4f}'.format(loss_bin[bin_keys[i]].loss_mean()) + ' | '
- log += 'lr:{:.8f}'.format(optimizer.param_groups[0]['lr'])+ ' | '
- log+='batch_time:{:.2f} s'.format(batch_time.avg)+ ' | '
- log+='total_time:{:.2f} min'.format(batch_time.avg * batch_idx / 60.0)+ ' | '
- log+='ETA:{:.2f} min'.format(batch_time.avg*(len(train_data_loader)-batch_idx)/60.0)
- print(log)
- loss_write = []
- for key in list(loss_bin.keys()):
- loss_write.append(loss_bin[key].loss_mean())
- return loss_write,loss_bin['loss_ctc'].loss_mean()
-
+ keys = list(loss_dict.keys())
+ for i in range(len(keys)):
+ log += keys[i] + ':{:.4f}'.format(loss_dict[keys[i]].avg) + ' | '
+ log += 'lr:{:.8f}'.format(optimizer.param_groups[0]['lr'])
+ train_log.info(log)
-def ModelEval(test_data_loader,LabelConverter,model,criterion,config):
- bar = tqdm(total=len(test_data_loader))
- loss_avg = []
- n_correct = 0
- for batch_idx, (imgs, labels) in enumerate(test_data_loader):
+def ModelEval(val_data_loader,LabelConverter, model,criterion,train_log,loss_dict,config):
+
+ bar = tqdm(total=len(val_data_loader))
+ val_loss = AverageMeter()
+ n_correct = AverageMeter()
+
+ for batch_idx, (imgs, labels) in enumerate(val_data_loader):
bar.update(1)
+
pre_batch = {}
gt_batch = {}
+
if torch.cuda.is_available():
- imgs = imgs.cuda()
+ imgs = imgs.cuda()
+
with torch.no_grad():
- preds = model(imgs)
+ preds,feau = model(imgs)
+ preds = preds.permute(1, 0, 2)
+
labels_class, labels_len = LabelConverter.encode(labels,preds.size(0))
- preds_size = Variable(torch.IntTensor([preds.size(0)] * config['testload']['batch_size']))
+ preds_size = Variable(torch.IntTensor([preds.size(0)] * config['valload']['batch_size']))
pre_batch['preds'],pre_batch['preds_size'] = preds,preds_size
gt_batch['labels'],gt_batch['labels_len'] = labels_class,labels_len
+
+ if config['loss']['use_ctc_weight']:
+ len_index = torch.softmax(preds,-1).max(2)[1].transpose(0,1)>0
+ len_flag = torch.cat([labels_len.cuda().long().unsqueeze(0),len_index.sum(1).unsqueeze(0)],0)
+ ctc_loss_weight = len_flag.max(0)[0].float()/len_flag.min(0)[0].float()
+                ctc_loss_weight[torch.isinf(ctc_loss_weight)] = 2.0
+ gt_batch['ctc_loss_weight'] = ctc_loss_weight
+
cost = criterion(pre_batch, gt_batch)
- loss_avg.append(cost.item())
+ val_loss.update(cost.item())
+
_, preds = preds.max(2)
- preds = preds.squeeze(1)
- preds = preds.transpose(1, 0).contiguous().view(-1)
+ preds = preds.squeeze(1).transpose(1, 0).contiguous().view(-1)
+
sim_preds = LabelConverter.decode(preds.data, preds_size.data, raw=False)
for pred, target in zip(sim_preds, labels):
if pred == target:
- n_correct += 1
+ n_correct.update(1)
raw_preds = LabelConverter.decode(preds.data, preds_size.data, raw=True)[:config['base']['show_num']]
for raw_pred, pred, gt in zip(raw_preds, sim_preds, labels):
- print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))
+ train_log.info('recog example %-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))
- val_acc = n_correct / float(len(test_data_loader) * config['testload']['batch_size'])
- val_loss = np.mean(loss_avg)
- print('Test loss: %f, accuray: %f' % (val_loss, val_acc))
- return val_acc,val_loss
+ val_acc = n_correct.sum / float(len(val_data_loader) * config['valload']['batch_size'])
+    train_log.info('val loss: %f, val accuracy: %f' % (val_loss.avg, val_acc))
+ return val_acc
-def TrainValProgram(config):
+
+def TrainValProgram(args):
config = yaml.load(open(args.config, 'r', encoding='utf-8'),Loader=yaml.FullLoader)
config = merge_config(config,args)
@@ -195,8 +130,7 @@ def TrainValProgram(config):
os.environ["CUDA_VISIBLE_DEVICES"] = config['base']['gpu_id']
create_dir(config['base']['checkpoints'])
- checkpoints = os.path.join(config['base']['checkpoints'],
- "ag_%s_bb_%s_he_%s_bs_%d_ep_%d_%s" % (config['base']['algorithm'],
+ checkpoints = os.path.join(config['base']['checkpoints'],"ag_%s_bb_%s_he_%s_bs_%d_ep_%d_%s" % (config['base']['algorithm'],
config['backbone']['function'].split(',')[-1],
config['head']['function'].split(',')[-1],
config['trainload']['batch_size'],
@@ -208,40 +142,36 @@ def TrainValProgram(config):
config['base']['classes'] = len(LabelConverter.alphabet)
model = create_module(config['architectures']['model_function'])(config)
criterion = create_module(config['architectures']['loss_function'])(config)
- train_dataset = create_module(config['trainload']['function'])(config)
- test_dataset = create_module(config['testload']['function'])(config)
- optimizer = create_module(config['optimizer']['function'])(config, model)
+ train_dataset = create_module(config['trainload']['function'])(config,'train')
+ val_dataset = create_module(config['valload']['function'])(config,'val')
+ optimizer = create_module(config['optimizer']['function'])(config, model.parameters())
optimizer_decay = create_module(config['optimizer_decay']['function'])
-
- if config['loss']['use_center']:
- center_criterion = create_module(config['loss']['center_function'])(config['base']['classes'],config['base']['classes'])
- optimizer_center = torch.optim.Adam(center_criterion.parameters(), lr= config['loss']['center_lr'])
- optimizer_decay_center = create_module(config['optimizer_decay_center']['function'])
- else:
- center_criterion = None
- optimizer_center=None
-
-
+ if os.path.exists(os.path.join(checkpoints,'train_log.txt')):
+ os.remove(os.path.join(checkpoints,'train_log.txt'))
+ train_log = TrainLog(os.path.join(checkpoints,'train_log.txt'))
+ train_log.info(model)
+
train_data_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=config['trainload']['batch_size'],
shuffle=True,
num_workers=config['trainload']['num_workers'],
- worker_init_fn = worker_init_fn,
collate_fn = alignCollate(),
drop_last=True,
pin_memory=True)
- test_data_loader = torch.utils.data.DataLoader(
- test_dataset,
- batch_size=config['testload']['batch_size'],
+ val_data_loader = torch.utils.data.DataLoader(
+ val_dataset,
+ batch_size=config['valload']['batch_size'],
shuffle=False,
- num_workers=config['testload']['num_workers'],
+ num_workers=config['valload']['num_workers'],
collate_fn = alignCollate(),
drop_last=True,
pin_memory=True)
-
- loss_bin = create_loss_bin(config['base']['algorithm'],use_center=config['loss']['use_center'])
+
+ loss_dict = {}
+ for title in config['loss']['loss_title']:
+ loss_dict[title] = AverageMeter()
if torch.cuda.is_available():
if (len(config['base']['gpu_id'].split(',')) > 1):
@@ -254,51 +184,19 @@ def TrainValProgram(config):
val_acc = 0
val_loss = 0
best_acc = 0
-
-
+
if config['base']['restore']:
- print('Resuming from checkpoint.')
+ train_log.info('Resuming from checkpoint.')
assert os.path.isfile(config['base']['restore_file']), 'Error: no checkpoint file found!'
- checkpoint = torch.load(config['base']['restore_file'])
- start_epoch = checkpoint['epoch']
- model.load_state_dict(checkpoint['state_dict'])
- optimizer.load_state_dict(checkpoint['optimizer'])
- best_acc = checkpoint['best_acc']
- if not config['loss']['use_center']:
- log_write = Logger(os.path.join(checkpoints, 'log.txt'), title=config['base']['algorithm'], resume=True)
- if config['loss']['use_center']:
- if os.path.exists(os.path.join(checkpoints, 'log_center.txt')):
- log_write = Logger(os.path.join(checkpoints, 'log_center.txt'), title=config['base']['algorithm'], resume=True)
- else:
- log_write = Logger(os.path.join(checkpoints, 'log_center.txt'), title=config['base']['algorithm'])
- title = list(loss_bin.keys())
- title.extend(['val_loss','test_acc','best_acc'])
- log_write.set_names(title)
- else:
- print('Training from scratch.')
- log_write = Logger(os.path.join(checkpoints, 'log.txt'), title=config['base']['algorithm'])
- title = list(loss_bin.keys())
- title.extend(['val_loss','test_acc','best_acc'])
- log_write.set_names(title)
- center_flag = False
- center_model = None
- if config['base']['finetune']:
- start_epoch = 0
- optimizer.param_groups[0]['lr'] = 0.0001
- center_flag = True
- center_model = copy.deepcopy(model)
+ model,optimizer,start_epoch,best_acc = restore_training(config['base']['restore_file'],model,optimizer)
+
for epoch in range(start_epoch,config['base']['n_epoch']):
model.train()
optimizer_decay(config, optimizer, epoch)
- if config['loss']['use_center']:
- optimizer_decay_center(config, optimizer_center, epoch)
- loss_write,loss_flag = ModelTrain(train_data_loader,LabelConverter, model,center_model, criterion, optimizer, center_criterion,optimizer_center,center_flag,loss_bin, config, epoch)
-# if loss_flag < config['loss']['min_score']:
-# center_flag = True
+ ModelTrain(train_data_loader,LabelConverter, model,criterion,optimizer, train_log,loss_dict, config, epoch)
if(epoch >= config['base']['start_val']):
model.eval()
- val_acc,val_loss = ModelEval(test_data_loader,LabelConverter, model,criterion ,config)
- print('val_acc:',val_acc,'val_loss',val_loss)
+ val_acc = ModelEval(val_data_loader,LabelConverter, model,criterion,train_log,loss_dict,config)
if (val_acc > best_acc):
save_checkpoint({
'epoch': epoch + 1,
@@ -308,13 +206,11 @@ def TrainValProgram(config):
'best_acc': val_acc
}, checkpoints, config['base']['algorithm'] + '_best' + '.pth.tar')
best_acc = val_acc
-
-
- loss_write.extend([val_loss,val_acc,best_acc])
- log_write.append(loss_write)
- for key in loss_bin.keys():
- loss_bin[key].loss_clear()
- if epoch%config['base']['save_epoch'] ==0:
+ train_log.info('best_acc:' + str(best_acc))
+ for key in loss_dict.keys():
+ loss_dict[key].reset()
+
+ if epoch % config['base']['save_epoch'] == 0:
save_checkpoint({
'epoch': epoch + 1,
'state_dict': model.state_dict(),
@@ -322,15 +218,12 @@ def TrainValProgram(config):
'optimizer': optimizer.state_dict(),
'best_acc': 0
},checkpoints,config['base']['algorithm']+'_'+str(epoch)+'.pth.tar')
-
+
if __name__ == "__main__":
-
parser = argparse.ArgumentParser(description='Hyperparams')
parser.add_argument('--config', help='config file path')
parser.add_argument('--log_str', help='log title')
args = parser.parse_args()
-
-
TrainValProgram(args)
\ No newline at end of file
diff --git a/tools/rec_train_bk1.py b/tools/rec_train_bk1.py
new file mode 100644
index 0000000..811d27c
--- /dev/null
+++ b/tools/rec_train_bk1.py
@@ -0,0 +1,396 @@
+# -*- coding:utf-8 _*-
+"""
+@author:fxw
+@file: det_train.py
+@time: 2020/08/07
+"""
+import sys
+sys.path.append('./')
+import cv2
+import torch
+import time
+import os
+import argparse
+import random
+import numpy as np
+from tqdm import tqdm
+from torch.autograd import Variable
+np.seterr(divide='ignore', invalid='ignore')
+import yaml
+import torch.utils.data
+from ptocr.utils.util_function import create_module, create_loss_bin, \
+set_seed,save_checkpoint,create_dir
+from ptocr.utils.metrics import runningScore
+from ptocr.utils.logger import Logger
+from ptocr.utils.util_function import create_process_obj,merge_config,AverageMeter
+from ptocr.dataloader.RecLoad.CRNNProcess import alignCollate
+import copy
+import re
+
+GLOBAL_WORKER_ID = None
+GLOBAL_SEED = 2020
+
+
+torch.manual_seed(GLOBAL_SEED)
+torch.cuda.manual_seed(GLOBAL_SEED)
+torch.cuda.manual_seed_all(GLOBAL_SEED)
+np.random.seed(GLOBAL_SEED)
+random.seed(GLOBAL_SEED)
+
+
+
+def worker_init_fn(worker_id):
+ global GLOBAL_WORKER_ID
+ GLOBAL_WORKER_ID = worker_id
+ set_seed(GLOBAL_SEED + worker_id)
+
+# def backward_hook(self,grad_input, grad_output):
+# for g in grad_input:
+# g[g != g] = 0 # replace all nan/inf in gradients to zero
+
+
+def ModelEval(test_data_loaders,LabelConverter,model,criterion,config):
+ loss_all= []
+ acc_all = []
+ for test_data_loader in test_data_loaders:
+ bar = tqdm(total=len(test_data_loader))
+ loss_avg = []
+ n_correct = 0
+ for batch_idx, (imgs, labels) in enumerate(test_data_loader):
+ bar.update(1)
+ pre_batch = {}
+ gt_batch = {}
+ if torch.cuda.is_available():
+ imgs = imgs.cuda()
+ with torch.no_grad():
+ preds,feau = model(imgs)
+ preds = preds.permute(1, 0, 2)
+
+ labels_class, labels_len = LabelConverter.encode(labels,preds.size(0))
+ preds_size = Variable(torch.IntTensor([preds.size(0)] * config['valload']['batch_size']))
+ pre_batch['preds'],pre_batch['preds_size'] = preds,preds_size
+ gt_batch['labels'],gt_batch['labels_len'] = labels_class,labels_len
+
+ if config['loss']['use_ctc_weight']:
+ len_index = torch.softmax(preds,-1).max(2)[1].transpose(0,1)>0
+ len_flag = torch.cat([labels_len.cuda().long().unsqueeze(0),len_index.sum(1).unsqueeze(0)],0)
+ ctc_loss_weight = len_flag.max(0)[0].float()/len_flag.min(0)[0].float()
+ ctc_loss_weight[ctc_loss_weight==torch.tensor(np.inf).cuda()]=2.0
+ gt_batch['ctc_loss_weight'] = ctc_loss_weight
+
+ cost = criterion(pre_batch, gt_batch)
+ loss_avg.append(cost.item())
+ _, preds = preds.max(2)
+            preds = preds.squeeze(1).transpose(1, 0).contiguous().view(-1)  # batch-major flatten, as in the main script
+            sim_preds = LabelConverter.decode(preds.data, preds_size.data, raw=False)
+            for pred, target in zip(sim_preds, labels):
+ target = ''.join(re.findall('[0-9a-zA-Z]+',target)).lower()
+ if pred == target:
+ n_correct += 1
+ val_acc = n_correct / float(len(test_data_loader) * config['valload']['batch_size'])
+ val_loss = np.mean(loss_avg)
+ loss_all.append(val_loss)
+ acc_all.append(val_acc)
+    print('Test acc:', acc_all)
+ return loss_all,acc_all
+
+def ModelTrain(train_data_loader,test_data_loader,LabelConverter,model,center_model, criterion, optimizer,center_criterion,optimizer_center,center_flag,loss_bin, config,optimizer_decay,optimizer_decay_center,log_write,checkpoints):
+ batch_time = AverageMeter()
+ end = time.time()
+ best_acc = 0
+ dataloader_bin = []
+ fid = open('test.txt','w+')
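+    # One iterator per training loader: each step draws a sub-batch from every
+    # loader and concatenates them, mixing the datasets at a fixed ratio.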
+ for i in range(len(train_data_loader)):
+ dataloader_bin.append(enumerate(train_data_loader[i]))
+
+ for iters in range(config['base']['start_iters'],config['base']['max_iters']):
+ model.train()
+ optimizer_decay(config, optimizer, iters)
+ if config['loss']['use_center']:
+ optimizer_decay_center(config, optimizer_center, iters)
+
+ imgs = []
+ labels = []
+ try:
+ for i in range(len(train_data_loader)):
+ index,(img,label) = next(dataloader_bin[i])
+ imgs.append(img)
+ labels.extend(label)
+ imgs = torch.cat(imgs,0)
+        except StopIteration:  # a loader is exhausted; restart all iterators and skip this step
+ for i in range(len(train_data_loader)):
+ dataloader_bin[i] = enumerate(train_data_loader[i])
+ continue
+
+ pre_batch = {}
+ gt_batch = {}
+
+ if torch.cuda.is_available():
+ imgs = imgs.cuda()
+ preds,feau = model(imgs)
+
+ preds = preds.permute(1, 0, 2)
+
+ labels,labels_len = LabelConverter.encode(labels,preds.size(0))
+ preds_size = Variable(torch.IntTensor([preds.size(0)] * config['trainload']['batch_size']))
+ pre_batch['preds'],pre_batch['preds_size'] = preds,preds_size
+ gt_batch['labels'],gt_batch['labels_len'] = labels,labels_len
+ #########
+
+# import pdb
+# pdb.set_trace()
+
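+        # Dynamic CTC weighting, same scheme as the eval path: upweight samples
+        # whose predicted non-blank length diverges from the label length.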
+ if config['loss']['use_ctc_weight']:
+# print('use')
+
+ len_index = torch.softmax(preds,-1).max(2)[1].transpose(0,1)>0
+ len_flag = torch.cat([labels_len.cuda().long().unsqueeze(0),len_index.sum(1).unsqueeze(0)],0)
+ ctc_loss_weight = len_flag.max(0)[0].float()/len_flag.min(0)[0].float()
+
+ ctc_loss_weight[ctc_loss_weight==torch.tensor(np.inf).cuda()]=2.0
+ gt_batch['ctc_loss_weight'] = ctc_loss_weight
+
+ ctc_loss = criterion(pre_batch, gt_batch).cuda()
+ metrics = {}
+ metrics['loss_total'] = 0.0
+ metrics['loss_center'] = 0.0
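+        # Center-loss phase: a frozen copy of the model (center_model) labels
+        # each time step; confident, deduplicated non-blank frames supervise
+        # center_criterion on the matching features, on top of the CTC loss.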
+ if center_criterion is not None and center_flag is True:
+ center_model.eval()
+ #####
+            features = preds.clone()
+ with torch.no_grad():
+ center_preds,center_feau = center_model(imgs)
+ center_preds = center_preds.permute(1, 0, 2)
+
+ center_preds = torch.softmax(center_preds,-1)
+ confs, center_preds = center_preds.max(2)
+ center_preds = center_preds.squeeze(1).transpose(1, 0).contiguous()
+ confs = confs.transpose(1, 0).contiguous()
+
+# confs = []
+# for i in range(center_preds.shape[0]):
+# conf = []
+# for j in range(len(center_preds[i])):
+# conf.append(probs[i,j,center_preds[i][j]])
+# confs.append(conf)
+# confs = torch.Tensor(confs).cuda()
+
+ b,t = center_preds.shape
+
+#             features = features.transpose(1, 0).contiguous()
+            features = center_feau[0].transpose(1, 0).contiguous()
+
+# import pdb
+# pdb.set_trace()
+            ### collapse CTC repeats: a position equal to its right neighbour is zeroed (treated as blank)
+ repeat_index = (center_preds[:,:-1] == center_preds[:,1:])
+ center_preds[:,:-1][repeat_index] = 0
+
+            confs = confs.view(-1)
+            center_preds = center_preds.view(-1)
+            features = features.view(b*t,-1)
+
+
+            index = (center_preds>0) & (confs>config['loss']['label_score'])
+            center_preds = center_preds[index]
+            features = features[index]
+
+            center_loss = center_criterion(features,center_preds)*config['loss']['weight_center']
+
+ loss = ctc_loss + center_loss
+
+
+ metrics['loss_total'] = loss.item()
+ metrics['loss_ctc'] = ctc_loss.item()
+ metrics['loss_center'] = center_loss.item()
+
+ #####
+ optimizer_center.zero_grad()
+ optimizer.zero_grad()
+
+ loss.backward()
+ optimizer.step()
+
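+            # Cancel the weight_center scaling so the class centers are updated
+            # with the unweighted center-loss gradient.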
+ for param in center_criterion.parameters():
+ param.grad.data *= (1. / config['loss']['weight_center'])
+ optimizer_center.step()
+ else:
+ loss = ctc_loss
+ metrics['loss_ctc'] = ctc_loss.item()
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ for key in loss_bin.keys():
+ loss_bin[key].loss_add(metrics[key])
+ batch_time.update(time.time() - end)
+ end = time.time()
+ if (iters % config['base']['show_step'] == 0):
+ log = '({}/{}) | ' \
+ .format(iters,config['base']['max_iters'])
+ bin_keys = list(loss_bin.keys())
+
+ for i in range(len(bin_keys)):
+ log += bin_keys[i] + ':{:.4f}'.format(loss_bin[bin_keys[i]].loss_mean()) + ' | '
+ log += 'lr:{:.8f}'.format(optimizer.param_groups[0]['lr'])+ ' | '
+ log+='batch_time:{:.2f} s'.format(batch_time.avg)+ ' | '
+ log+='total_time:{:.2f} min'.format(batch_time.avg * iters / 60.0)+ ' | '
+ log+='ETA:{:.2f} min'.format(batch_time.avg*(config['base']['max_iters']-iters)/60.0)
+ print(log)
+
+
+ if(iters % config['base']['eval_iter']==0 and iters!=0):
+
+ loss_write = []
+ for key in list(loss_bin.keys()):
+ loss_write.append(loss_bin[key].loss_mean())
+
+ model.eval()
+ val_loss,acc_all = ModelEval(test_data_loader,LabelConverter, model,criterion ,config)
+ val_acc = np.mean(acc_all)
+ if (val_acc > best_acc):
+ save_checkpoint({
+ 'iters': iters + 1,
+ 'state_dict': model.state_dict(),
+ 'lr': config['optimizer']['base_lr'],
+ 'optimizer': optimizer.state_dict(),
+ 'best_acc': val_acc
+ }, checkpoints, config['base']['algorithm'] + '_best' + '.pth.tar')
+ best_acc = val_acc
+ acc_all.append(val_acc)
+ acc_all = [str(x) for x in acc_all]
+ fid.write(str(','.join(acc_all))+'\n')
+ fid.flush()
+ loss_write.extend([0,0,0])
+ log_write.append(loss_write)
+ for key in loss_bin.keys():
+ loss_bin[key].loss_clear()
+ if iters %config['base']['eval_iter'] ==0:
+ save_checkpoint({
+ 'iters': iters + 1,
+ 'state_dict': model.state_dict(),
+ 'lr': config['optimizer']['base_lr'],
+ 'optimizer': optimizer.state_dict(),
+ 'best_acc': 0
+ },checkpoints,config['base']['algorithm']+'_'+str(iters)+'.pth.tar')
+
+
+
+
+
+
+
+
+
+def TrainValProgram(args):
+
+ config = yaml.load(open(args.config, 'r', encoding='utf-8'),Loader=yaml.FullLoader)
+ config = merge_config(config,args)
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = config['base']['gpu_id']
+
+ create_dir(config['base']['checkpoints'])
+ checkpoints = os.path.join(config['base']['checkpoints'],
+ "ag_%s_bb_%s_he_%s_bs_%d_ep_%d_%s" % (config['base']['algorithm'],
+ config['backbone']['function'].split(',')[-1],
+ config['head']['function'].split(',')[-1],
+ config['trainload']['batch_size'],
+ config['base']['max_iters'],
+ args.log_str))
+ create_dir(checkpoints)
+
+ LabelConverter = create_module(config['label_transform']['function'])(config)
+ config['base']['classes'] = len(LabelConverter.alphabet)
+ model = create_module(config['architectures']['model_function'])(config)
+ criterion = create_module(config['architectures']['loss_function'])(config)
+ train_data_loader = create_module(config['trainload']['function'])(config)
+ test_data_loader = create_module(config['valload']['function'])(config)
+ optimizer = create_module(config['optimizer']['function'])(config, model)
+ optimizer_decay = create_module(config['optimizer_decay']['function'])
+
+ if config['loss']['use_center']:
+# center_criterion = create_module(config['loss']['center_function'])(config['base']['classes'],config['base']['classes'])
+ center_criterion = create_module(config['loss']['center_function'])(config['base']['classes'],config['base']['hiddenchannel'])
+
+ optimizer_center = torch.optim.Adam(center_criterion.parameters(), lr= config['loss']['center_lr'])
+ optimizer_decay_center = create_module(config['optimizer_decay_center']['function'])
+ else:
+ center_criterion = None
+ optimizer_center=None
+ optimizer_decay_center = None
+
+
+
+
+ loss_bin = create_loss_bin(config['base']['algorithm'],use_center=config['loss']['use_center'])
+
+ if torch.cuda.is_available():
+ if (len(config['base']['gpu_id'].split(',')) > 1):
+ model = torch.nn.DataParallel(model).cuda()
+ else:
+ model = model.cuda()
+ criterion = criterion.cuda()
+
+
+
+ print(model)
+
+ ### model.head.lstm_2.embedding = nn.Linear(in_features=512, out_features=2000, bias=True)
+
+ if config['base']['restore']:
+ print('Resuming from checkpoint.')
+ assert os.path.isfile(config['base']['restore_file']), 'Error: no checkpoint file found!'
+ checkpoint = torch.load(config['base']['restore_file'])
+ start_epoch = checkpoint['epoch']
+# model.load_state_dict(checkpoint['state_dict'])
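+        # If direct loading fails (e.g. a DataParallel model loading a plain
+        # checkpoint), remap keys: key[7:] drops the 'module.' prefix when
+        # looking up the checkpoint entry.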
+ try:
+ model.load_state_dict(checkpoint['state_dict'])
+        except RuntimeError:
+ state = model.state_dict()
+ for key in state.keys():
+ state[key] = checkpoint['state_dict'][key[7:]]
+ model.load_state_dict(state)
+
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ best_acc = checkpoint['best_acc']
+ if not config['loss']['use_center']:
+ log_write = Logger(os.path.join(checkpoints, 'log.txt'), title=config['base']['algorithm'], resume=True)
+ if config['loss']['use_center']:
+ if os.path.exists(os.path.join(checkpoints, 'log_center.txt')):
+ log_write = Logger(os.path.join(checkpoints, 'log_center.txt'), title=config['base']['algorithm'], resume=True)
+ else:
+ log_write = Logger(os.path.join(checkpoints, 'log_center.txt'), title=config['base']['algorithm'])
+ title = list(loss_bin.keys())
+ title.extend(['val_loss','test_acc','best_acc'])
+ log_write.set_names(title)
+ else:
+ print('Training from scratch.')
+ log_write = Logger(os.path.join(checkpoints, 'log.txt'), title=config['base']['algorithm'])
+ title = list(loss_bin.keys())
+ title.extend(['val_loss','test_acc','best_acc'])
+ log_write.set_names(title)
+ center_flag = False
+ center_model = None
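+    # finetune enables the center-loss phase: restart at a reduced lr with a
+    # frozen copy of the current model as the pseudo-label provider.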
+ if config['base']['finetune']:
+ start_epoch = 0
+ optimizer.param_groups[0]['lr'] = 0.0001
+ center_flag = True
+ center_model = copy.deepcopy(model)
+
+
+    # ModelTrain runs the whole iteration loop itself and returns nothing.
+    ModelTrain(train_data_loader,test_data_loader,LabelConverter, model,center_model, criterion, optimizer, center_criterion,optimizer_center,center_flag,loss_bin, config,optimizer_decay,optimizer_decay_center,log_write,checkpoints)
+
+
+
+
+if __name__ == "__main__":
+
+
+ parser = argparse.ArgumentParser(description='Hyperparams')
+ parser.add_argument('--config', help='config file path')
+ parser.add_argument('--log_str', help='log title')
+ args = parser.parse_args()
+
+
+ TrainValProgram(args)
\ No newline at end of file
diff --git a/tools/rec_train_bk2.py b/tools/rec_train_bk2.py
new file mode 100644
index 0000000..ce6a9eb
--- /dev/null
+++ b/tools/rec_train_bk2.py
@@ -0,0 +1,332 @@
+# -*- coding:utf-8 _*-
+"""
+@author:fxw
+@file: det_train.py
+@time: 2020/08/07
+"""
+import sys
+sys.path.append('./')
+import cv2
+import torch
+import time
+import os
+import argparse
+import random
+import numpy as np
+from tqdm import tqdm
+from torch.autograd import Variable
+np.seterr(divide='ignore', invalid='ignore')
+import yaml
+import torch.utils.data
+from ptocr.utils.util_function import create_module, create_loss_bin, \
+set_seed,save_checkpoint,create_dir
+from ptocr.utils.metrics import runningScore
+from ptocr.utils.logger import Logger
+from ptocr.utils.util_function import create_process_obj,merge_config,AverageMeter
+from ptocr.dataloader.RecLoad.CRNNProcess import alignCollate
+import copy
+import re
+
+GLOBAL_WORKER_ID = None
+GLOBAL_SEED = 2020
+
+
+torch.manual_seed(GLOBAL_SEED)
+torch.cuda.manual_seed(GLOBAL_SEED)
+torch.cuda.manual_seed_all(GLOBAL_SEED)
+np.random.seed(GLOBAL_SEED)
+random.seed(GLOBAL_SEED)
+
+
+
+def worker_init_fn(worker_id):
+ global GLOBAL_WORKER_ID
+ GLOBAL_WORKER_ID = worker_id
+ set_seed(GLOBAL_SEED + worker_id)
+
+# def backward_hook(self,grad_input, grad_output):
+# for g in grad_input:
+# g[g != g] = 0 # replace all nan/inf in gradients to zero
+
+
+def ModelEval(test_data_loaders,LabelConverter,model,criterion,config):
+ loss_all= []
+ acc_all = []
+ for test_data_loader in test_data_loaders:
+ bar = tqdm(total=len(test_data_loader))
+ loss_avg = []
+ n_correct = 0
+ for batch_idx, (imgs, labels) in enumerate(test_data_loader):
+ bar.update(1)
+ pre_batch = {}
+ gt_batch = {}
+ if torch.cuda.is_available():
+ imgs = imgs.cuda()
+ with torch.no_grad():
+ preds,feau = model(imgs)
+
+
+ _,_, labels_class = LabelConverter.test_encode(labels)
+
+ pre_batch['pred'],gt_batch['gt'] = preds,labels_class
+
+
+            if config['loss']['use_ctc_weight']:
+                # labels_len was undefined in this variant (test_encode's lengths
+                # are discarded above); derive it from the raw label strings,
+                # assuming one class per character.
+                labels_len = torch.IntTensor([len(s) for s in labels])
+                len_index = torch.softmax(preds,-1).max(2)[1].transpose(0,1)>0
+                len_flag = torch.cat([labels_len.cuda().long().unsqueeze(0),len_index.sum(1).unsqueeze(0)],0)
+                ctc_loss_weight = len_flag.max(0)[0].float()/len_flag.min(0)[0].float()
+                ctc_loss_weight[ctc_loss_weight==torch.tensor(np.inf).cuda()]=2.0
+                gt_batch['ctc_loss_weight'] = ctc_loss_weight
+
+ cost,_ = criterion(pre_batch, gt_batch)
+ loss_avg.append(cost.item())
+ _, preds = preds.max(2)
+
+ sim_preds = LabelConverter.decode(preds.data)
+ for pred, target in zip(sim_preds, labels):
+ target = ''.join(re.findall('[0-9a-zA-Z]+',target)).lower()
+ if pred == target:
+ n_correct += 1
+ val_acc = n_correct / float(len(test_data_loader) * config['valload']['batch_size'])
+ val_loss = np.mean(loss_avg)
+ loss_all.append(val_loss)
+ acc_all.append(val_acc)
+    print('Test acc:', acc_all)
+ return loss_all,acc_all
+
+def ModelTrain(train_data_loader,test_data_loader,LabelConverter,model,center_model, criterion, optimizer,center_criterion,optimizer_center,center_flag,loss_bin, config,optimizer_decay,optimizer_decay_center,log_write,checkpoints):
+ batch_time = AverageMeter()
+ end = time.time()
+ best_acc = 0
+ dataloader_bin = []
+ fid = open('test.txt','w+')
+ for i in range(len(train_data_loader)):
+ dataloader_bin.append(enumerate(train_data_loader[i]))
+
+ for iters in range(config['base']['start_iters'],config['base']['max_iters']):
+ model.train()
+ optimizer_decay(config, optimizer, iters)
+ if config['loss']['use_center']:
+ optimizer_decay_center(config, optimizer_center, iters)
+
+ imgs = []
+ labels = []
+ try:
+ for i in range(len(train_data_loader)):
+ index,(img,label) = next(dataloader_bin[i])
+ imgs.append(img)
+ labels.extend(label)
+ imgs = torch.cat(imgs,0)
+        except StopIteration:  # a loader is exhausted; restart all iterators and skip this step
+ for i in range(len(train_data_loader)):
+ dataloader_bin[i] = enumerate(train_data_loader[i])
+ continue
+
+ pre_batch = {}
+ gt_batch = {}
+
+        # Capture raw label lengths before re-encoding; the use_ctc_weight branch
+        # below needs them (assumes one class per character).
+        labels_len = torch.IntTensor([len(s) for s in labels])
+        _, _, labels = LabelConverter.train_encode(labels)
+
+ if torch.cuda.is_available():
+ imgs = imgs.cuda()
+ preds,feau = model(imgs)
+
+
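+        # This variant feeds dense per-step class targets ('gt') to the criterion
+        # instead of CTC labels and lengths.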
+ pre_batch['pred'],gt_batch['gt'] = preds,labels.long().cuda()
+ #########
+
+
+
+ if config['loss']['use_ctc_weight']:
+# print('use')
+
+ len_index = torch.softmax(preds,-1).max(2)[1].transpose(0,1)>0
+ len_flag = torch.cat([labels_len.cuda().long().unsqueeze(0),len_index.sum(1).unsqueeze(0)],0)
+ ctc_loss_weight = len_flag.max(0)[0].float()/len_flag.min(0)[0].float()
+
+ ctc_loss_weight[ctc_loss_weight==torch.tensor(np.inf).cuda()]=2.0
+ gt_batch['ctc_loss_weight'] = ctc_loss_weight
+
+ loss,_ = criterion(pre_batch, gt_batch)
+
+# import pdb
+# pdb.set_trace()
+
+ metrics = {}
+ metrics['loss_fc'] = loss.item()
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ for key in loss_bin.keys():
+ loss_bin[key].loss_add(metrics[key])
+ batch_time.update(time.time() - end)
+ end = time.time()
+ if (iters % config['base']['show_step'] == 0):
+ log = '({}/{}) | ' \
+ .format(iters,config['base']['max_iters'])
+ bin_keys = list(loss_bin.keys())
+
+ for i in range(len(bin_keys)):
+ log += bin_keys[i] + ':{:.4f}'.format(loss_bin[bin_keys[i]].loss_mean()) + ' | '
+ log += 'lr:{:.8f}'.format(optimizer.param_groups[0]['lr'])+ ' | '
+ log+='batch_time:{:.2f} s'.format(batch_time.avg)+ ' | '
+ log+='total_time:{:.2f} min'.format(batch_time.avg * iters / 60.0)+ ' | '
+ log+='ETA:{:.2f} min'.format(batch_time.avg*(config['base']['max_iters']-iters)/60.0)
+ print(log)
+
+
+ if(iters % config['base']['eval_iter']==0 and iters!=0):
+
+ loss_write = []
+ for key in list(loss_bin.keys()):
+ loss_write.append(loss_bin[key].loss_mean())
+
+ model.eval()
+ val_loss,acc_all = ModelEval(test_data_loader,LabelConverter, model,criterion ,config)
+ val_acc = np.mean(acc_all)
+ if (val_acc > best_acc):
+ save_checkpoint({
+ 'iters': iters + 1,
+ 'state_dict': model.state_dict(),
+ 'lr': config['optimizer']['base_lr'],
+ 'optimizer': optimizer.state_dict(),
+ 'best_acc': val_acc
+ }, checkpoints, config['base']['algorithm'] + '_best' + '.pth.tar')
+ best_acc = val_acc
+ acc_all.append(val_acc)
+ acc_all = [str(x) for x in acc_all]
+ fid.write(str(','.join(acc_all))+'\n')
+ fid.flush()
+ loss_write.extend([0,0,0])
+ log_write.append(loss_write)
+ for key in loss_bin.keys():
+ loss_bin[key].loss_clear()
+ if iters %config['base']['eval_iter'] ==0:
+ save_checkpoint({
+ 'iters': iters + 1,
+ 'state_dict': model.state_dict(),
+ 'lr': config['optimizer']['base_lr'],
+ 'optimizer': optimizer.state_dict(),
+ 'best_acc': 0
+ },checkpoints,config['base']['algorithm']+'_'+str(iters)+'.pth.tar')
+
+
+
+
+
+
+
+
+
+def TrainValProgram(args):
+
+ config = yaml.load(open(args.config, 'r', encoding='utf-8'),Loader=yaml.FullLoader)
+ config = merge_config(config,args)
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = config['base']['gpu_id']
+
+ create_dir(config['base']['checkpoints'])
+ checkpoints = os.path.join(config['base']['checkpoints'],
+ "ag_%s_bb_%s_he_%s_bs_%d_ep_%d_%s" % (config['base']['algorithm'],
+ config['backbone']['function'].split(',')[-1],
+ config['head']['function'].split(',')[-1],
+ config['trainload']['batch_size'],
+ config['base']['max_iters'],
+ args.log_str))
+ create_dir(checkpoints)
+
+ LabelConverter = create_module(config['label_transform']['function'])(config)
+
+ model = create_module(config['architectures']['model_function'])(config)
+ criterion = create_module(config['architectures']['loss_function'])(config)
+ train_data_loader = create_module(config['trainload']['function'])(config)
+ test_data_loader = create_module(config['valload']['function'])(config)
+ optimizer = create_module(config['optimizer']['function'])(config, model)
+ optimizer_decay = create_module(config['optimizer_decay']['function'])
+
+ if config['loss']['use_center']:
+# center_criterion = create_module(config['loss']['center_function'])(config['base']['classes'],config['base']['classes'])
+ center_criterion = create_module(config['loss']['center_function'])(config['base']['classes'],config['base']['hiddenchannel'])
+
+ optimizer_center = torch.optim.Adam(center_criterion.parameters(), lr= config['loss']['center_lr'])
+ optimizer_decay_center = create_module(config['optimizer_decay_center']['function'])
+ else:
+ center_criterion = None
+ optimizer_center=None
+ optimizer_decay_center = None
+
+
+
+
+ loss_bin = create_loss_bin(config['base']['algorithm'],use_center=config['loss']['use_center'])
+
+ if torch.cuda.is_available():
+ if (len(config['base']['gpu_id'].split(',')) > 1):
+ model = torch.nn.DataParallel(model).cuda()
+ else:
+ model = model.cuda()
+ criterion = criterion.cuda()
+
+
+
+ print(model)
+
+ ### model.head.lstm_2.embedding = nn.Linear(in_features=512, out_features=2000, bias=True)
+
+ if config['base']['restore']:
+ print('Resuming from checkpoint.')
+ assert os.path.isfile(config['base']['restore_file']), 'Error: no checkpoint file found!'
+ checkpoint = torch.load(config['base']['restore_file'])
+ start_epoch = checkpoint['epoch']
+# model.load_state_dict(checkpoint['state_dict'])
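+        # Remap keys if direct loading fails (drop the DataParallel 'module.' prefix).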
+ try:
+ model.load_state_dict(checkpoint['state_dict'])
+        except RuntimeError:
+ state = model.state_dict()
+ for key in state.keys():
+ state[key] = checkpoint['state_dict'][key[7:]]
+ model.load_state_dict(state)
+
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ best_acc = checkpoint['best_acc']
+ if not config['loss']['use_center']:
+ log_write = Logger(os.path.join(checkpoints, 'log.txt'), title=config['base']['algorithm'], resume=True)
+ if config['loss']['use_center']:
+ if os.path.exists(os.path.join(checkpoints, 'log_center.txt')):
+ log_write = Logger(os.path.join(checkpoints, 'log_center.txt'), title=config['base']['algorithm'], resume=True)
+ else:
+ log_write = Logger(os.path.join(checkpoints, 'log_center.txt'), title=config['base']['algorithm'])
+ title = list(loss_bin.keys())
+ title.extend(['val_loss','test_acc','best_acc'])
+ log_write.set_names(title)
+ else:
+ print('Training from scratch.')
+ log_write = Logger(os.path.join(checkpoints, 'log.txt'), title=config['base']['algorithm'])
+ title = list(loss_bin.keys())
+ title.extend(['val_loss','test_acc','best_acc'])
+ log_write.set_names(title)
+ center_flag = False
+ center_model = None
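+    # Note: this variant's ModelTrain never applies the center loss; the center
+    # objects below only drive the optimizer_decay_center schedule.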
+ if config['base']['finetune']:
+ start_epoch = 0
+ optimizer.param_groups[0]['lr'] = 0.0001
+ center_flag = True
+ center_model = copy.deepcopy(model)
+
+
+    # ModelTrain runs the whole iteration loop itself and returns nothing.
+    ModelTrain(train_data_loader,test_data_loader,LabelConverter, model,center_model, criterion, optimizer, center_criterion,optimizer_center,center_flag,loss_bin, config,optimizer_decay,optimizer_decay_center,log_write,checkpoints)
+
+
+
+
+if __name__ == "__main__":
+
+
+ parser = argparse.ArgumentParser(description='Hyperparams')
+ parser.add_argument('--config', help='config file path')
+ parser.add_argument('--log_str', help='log title')
+ args = parser.parse_args()
+
+
+ TrainValProgram(args)
\ No newline at end of file
diff --git a/tools/rec_train_bk3.py b/tools/rec_train_bk3.py
new file mode 100644
index 0000000..67d6c11
--- /dev/null
+++ b/tools/rec_train_bk3.py
@@ -0,0 +1,376 @@
+# -*- coding:utf-8 _*-
+"""
+@author:fxw
+@file: det_train.py
+@time: 2020/08/07
+"""
+import sys
+sys.path.append('./')
+import cv2
+import torch
+import time
+import os
+import argparse
+import random
+import numpy as np
+from tqdm import tqdm
+from torch.autograd import Variable
+np.seterr(divide='ignore', invalid='ignore')
+import yaml
+import torch.utils.data
+from ptocr.utils.util_function import create_module, create_loss_bin, \
+set_seed,save_checkpoint,create_dir
+from ptocr.utils.metrics import runningScore
+from ptocr.utils.logger import Logger
+from ptocr.utils.util_function import create_process_obj,merge_config,AverageMeter
+from ptocr.dataloader.RecLoad.CRNNProcess import alignCollate
+import copy
+
+GLOBAL_WORKER_ID = None
+GLOBAL_SEED = 2020
+
+
+torch.manual_seed(GLOBAL_SEED)
+torch.cuda.manual_seed(GLOBAL_SEED)
+torch.cuda.manual_seed_all(GLOBAL_SEED)
+np.random.seed(GLOBAL_SEED)
+random.seed(GLOBAL_SEED)
+
+
+
+def worker_init_fn(worker_id):
+ global GLOBAL_WORKER_ID
+ GLOBAL_WORKER_ID = worker_id
+ set_seed(GLOBAL_SEED + worker_id)
+
+# def backward_hook(self,grad_input, grad_output):
+# for g in grad_input:
+# g[g != g] = 0 # replace all nan/inf in gradients to zero
+
+
+def ModelTrain(train_data_loader,LabelConverter,model,center_model, criterion, optimizer,center_criterion,optimizer_center,center_flag,loss_bin, config, epoch):
+ batch_time = AverageMeter()
+ end = time.time()
+ for batch_idx, data in enumerate(train_data_loader):
+# model.register_backward_hook(backward_hook)
+ if(data is None):
+ continue
+ imgs,labels = data
+ pre_batch = {}
+ gt_batch = {}
+
+ if torch.cuda.is_available():
+ imgs = imgs.cuda()
+ preds,feau = model(imgs)
+
+ preds = preds.permute(1, 0, 2)
+
+ labels,labels_len = LabelConverter.encode(labels,preds.size(0))
+ preds_size = Variable(torch.IntTensor([preds.size(0)] * config['trainload']['batch_size']))
+ pre_batch['preds'],pre_batch['preds_size'] = preds,preds_size
+ gt_batch['labels'],gt_batch['labels_len'] = labels,labels_len
+ #########
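+        # Dynamic CTC weighting: per-sample max/min ratio between label length
+        # and predicted non-blank frame count; inf ratios are clamped to 2.0.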
+ if config['loss']['use_ctc_weight']:
+
+ len_index = torch.softmax(preds,-1).max(2)[1].transpose(0,1)>0
+ len_flag = torch.cat([labels_len.cuda().long().unsqueeze(0),len_index.sum(1).unsqueeze(0)],0)
+ ctc_loss_weight = len_flag.max(0)[0].float()/len_flag.min(0)[0].float()
+
+ ctc_loss_weight[ctc_loss_weight==torch.tensor(np.inf).cuda()]=2.0
+ gt_batch['ctc_loss_weight'] = ctc_loss_weight
+
+ ctc_loss = criterion(pre_batch, gt_batch).cuda()
+ metrics = {}
+ metrics['loss_total'] = 0.0
+ metrics['loss_center'] = 0.0
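+        # Center-loss phase (enabled via finetune below): a frozen center_model
+        # labels each frame; confident non-blank frames supervise the center loss.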
+ if center_criterion is not None and center_flag is True:
+ center_model.eval()
+ #####
+            features = preds.clone()
+ with torch.no_grad():
+ center_preds,center_feau = center_model(imgs)
+ center_preds = center_preds.permute(1, 0, 2)
+
+ center_preds = torch.softmax(center_preds,-1)
+ confs, center_preds = center_preds.max(2)
+ center_preds = center_preds.squeeze(1).transpose(1, 0).contiguous()
+ confs = confs.transpose(1, 0).contiguous()
+
+# confs = []
+# for i in range(center_preds.shape[0]):
+# conf = []
+# for j in range(len(center_preds[i])):
+# conf.append(probs[i,j,center_preds[i][j]])
+# confs.append(conf)
+# confs = torch.Tensor(confs).cuda()
+
+ b,t = center_preds.shape
+
+#             features = features.transpose(1, 0).contiguous()
+            features = center_feau[0].transpose(1, 0).contiguous()
+
+# import pdb
+# pdb.set_trace()
+            ### collapse CTC repeats: a position equal to its right neighbour is zeroed (treated as blank)
+ repeat_index = (center_preds[:,:-1] == center_preds[:,1:])
+ center_preds[:,:-1][repeat_index] = 0
+
+            confs = confs.view(-1)
+            center_preds = center_preds.view(-1)
+            features = features.view(b*t,-1)
+
+
+            index = (center_preds>0) & (confs>config['loss']['label_score'])
+            center_preds = center_preds[index]
+            features = features[index]
+
+            center_loss = center_criterion(features,center_preds)*config['loss']['weight_center']
+
+ loss = ctc_loss + center_loss
+
+
+ metrics['loss_total'] = loss.item()
+ metrics['loss_ctc'] = ctc_loss.item()
+ metrics['loss_center'] = center_loss.item()
+
+ #####
+ optimizer_center.zero_grad()
+ optimizer.zero_grad()
+
+ loss.backward()
+ optimizer.step()
+
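+            # Undo the weight_center scaling before stepping the center optimizer.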
+ for param in center_criterion.parameters():
+ param.grad.data *= (1. / config['loss']['weight_center'])
+ optimizer_center.step()
+ else:
+ loss = ctc_loss
+ metrics['loss_ctc'] = ctc_loss.item()
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ for key in loss_bin.keys():
+ loss_bin[key].loss_add(metrics[key])
+ batch_time.update(time.time() - end)
+ end = time.time()
+ if (batch_idx % config['base']['show_step'] == 0):
+ log = '({}/{}/{}/{}) | ' \
+ .format(epoch, config['base']['n_epoch'], batch_idx, len(train_data_loader))
+ bin_keys = list(loss_bin.keys())
+
+ for i in range(len(bin_keys)):
+ log += bin_keys[i] + ':{:.4f}'.format(loss_bin[bin_keys[i]].loss_mean()) + ' | '
+ log += 'lr:{:.8f}'.format(optimizer.param_groups[0]['lr'])+ ' | '
+ log+='batch_time:{:.2f} s'.format(batch_time.avg)+ ' | '
+ log+='total_time:{:.2f} min'.format(batch_time.avg * batch_idx / 60.0)+ ' | '
+ log+='ETA:{:.2f} min'.format(batch_time.avg*(len(train_data_loader)-batch_idx)/60.0)
+ print(log)
+ loss_write = []
+ for key in list(loss_bin.keys()):
+ loss_write.append(loss_bin[key].loss_mean())
+ return loss_write,loss_bin['loss_ctc'].loss_mean()
+
+
+def ModelEval(test_data_loader,LabelConverter,model,criterion,config):
+ bar = tqdm(total=len(test_data_loader))
+ loss_avg = []
+ n_correct = 0
+ for batch_idx, (imgs, labels) in enumerate(test_data_loader):
+ bar.update(1)
+ pre_batch = {}
+ gt_batch = {}
+ if torch.cuda.is_available():
+ imgs = imgs.cuda()
+ with torch.no_grad():
+ preds,feau = model(imgs)
+ preds = preds.permute(1, 0, 2)
+
+ labels_class, labels_len = LabelConverter.encode(labels,preds.size(0))
+ preds_size = Variable(torch.IntTensor([preds.size(0)] * config['testload']['batch_size']))
+ pre_batch['preds'],pre_batch['preds_size'] = preds,preds_size
+ gt_batch['labels'],gt_batch['labels_len'] = labels_class,labels_len
+
+ if config['loss']['use_ctc_weight']:
+ len_index = torch.softmax(preds,-1).max(2)[1].transpose(0,1)>0
+ len_flag = torch.cat([labels_len.cuda().long().unsqueeze(0),len_index.sum(1).unsqueeze(0)],0)
+ ctc_loss_weight = len_flag.max(0)[0].float()/len_flag.min(0)[0].float()
+ ctc_loss_weight[ctc_loss_weight==torch.tensor(np.inf).cuda()]=2.0
+ gt_batch['ctc_loss_weight'] = ctc_loss_weight
+
+ cost = criterion(pre_batch, gt_batch)
+ loss_avg.append(cost.item())
+ _, preds = preds.max(2)
+ preds = preds.squeeze(1)
+ preds = preds.transpose(1, 0).contiguous().view(-1)
+ sim_preds = LabelConverter.decode(preds.data, preds_size.data, raw=False)
+ for pred, target in zip(sim_preds, labels):
+ if pred == target:
+ n_correct += 1
+ raw_preds = LabelConverter.decode(preds.data, preds_size.data, raw=True)[:config['base']['show_num']]
+ for raw_pred, pred, gt in zip(raw_preds, sim_preds, labels):
+ print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))
+
+ val_acc = n_correct / float(len(test_data_loader) * config['testload']['batch_size'])
+ val_loss = np.mean(loss_avg)
+    print('Test loss: %f, accuracy: %f' % (val_loss, val_acc))
+ return val_acc,val_loss
+
+def TrainValProgram(args):
+
+ config = yaml.load(open(args.config, 'r', encoding='utf-8'),Loader=yaml.FullLoader)
+ config = merge_config(config,args)
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = config['base']['gpu_id']
+
+ create_dir(config['base']['checkpoints'])
+ checkpoints = os.path.join(config['base']['checkpoints'],
+ "ag_%s_bb_%s_he_%s_bs_%d_ep_%d_%s" % (config['base']['algorithm'],
+ config['backbone']['function'].split(',')[-1],
+ config['head']['function'].split(',')[-1],
+ config['trainload']['batch_size'],
+ config['base']['n_epoch'],
+ args.log_str))
+ create_dir(checkpoints)
+
+ LabelConverter = create_module(config['label_transform']['function'])(config)
+ config['base']['classes'] = len(LabelConverter.alphabet)
+ model = create_module(config['architectures']['model_function'])(config)
+ criterion = create_module(config['architectures']['loss_function'])(config)
+ train_dataset = create_module(config['trainload']['function'])(config)
+ test_dataset = create_module(config['testload']['function'])(config)
+ optimizer = create_module(config['optimizer']['function'])(config, model)
+ optimizer_decay = create_module(config['optimizer_decay']['function'])
+
+ if config['loss']['use_center']:
+# center_criterion = create_module(config['loss']['center_function'])(config['base']['classes'],config['base']['classes'])
+ center_criterion = create_module(config['loss']['center_function'])(config['base']['classes'],config['base']['hiddenchannel'])
+
+ optimizer_center = torch.optim.Adam(center_criterion.parameters(), lr= config['loss']['center_lr'])
+ optimizer_decay_center = create_module(config['optimizer_decay_center']['function'])
+ else:
+ center_criterion = None
+ optimizer_center=None
+
+
+ train_data_loader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=config['trainload']['batch_size'],
+ shuffle=True,
+ num_workers=config['trainload']['num_workers'],
+ worker_init_fn = worker_init_fn,
+ collate_fn = alignCollate(),
+ drop_last=True,
+ pin_memory=True)
+
+ test_data_loader = torch.utils.data.DataLoader(
+ test_dataset,
+ batch_size=config['testload']['batch_size'],
+ shuffle=False,
+ num_workers=config['testload']['num_workers'],
+ collate_fn = alignCollate(),
+ drop_last=True,
+ pin_memory=True)
+
+ loss_bin = create_loss_bin(config['base']['algorithm'],use_center=config['loss']['use_center'])
+
+ if torch.cuda.is_available():
+ if (len(config['base']['gpu_id'].split(',')) > 1):
+ model = torch.nn.DataParallel(model).cuda()
+ else:
+ model = model.cuda()
+ criterion = criterion.cuda()
+
+ start_epoch = 0
+ val_acc = 0
+ val_loss = 0
+ best_acc = 0
+
+ print(model)
+
+ if config['base']['restore']:
+ print('Resuming from checkpoint.')
+ assert os.path.isfile(config['base']['restore_file']), 'Error: no checkpoint file found!'
+ checkpoint = torch.load(config['base']['restore_file'])
+ start_epoch = checkpoint['epoch']
+# model.load_state_dict(checkpoint['state_dict'])
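+        # Remap keys if direct loading fails (key[7:] drops the DataParallel
+        # 'module.' prefix when looking up the checkpoint).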
+ try:
+ model.load_state_dict(checkpoint['state_dict'])
+        except RuntimeError:
+ state = model.state_dict()
+ for key in state.keys():
+ state[key] = checkpoint['state_dict'][key[7:]]
+ model.load_state_dict(state)
+
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ best_acc = checkpoint['best_acc']
+ if not config['loss']['use_center']:
+ log_write = Logger(os.path.join(checkpoints, 'log.txt'), title=config['base']['algorithm'], resume=True)
+ if config['loss']['use_center']:
+ if os.path.exists(os.path.join(checkpoints, 'log_center.txt')):
+ log_write = Logger(os.path.join(checkpoints, 'log_center.txt'), title=config['base']['algorithm'], resume=True)
+ else:
+ log_write = Logger(os.path.join(checkpoints, 'log_center.txt'), title=config['base']['algorithm'])
+ title = list(loss_bin.keys())
+ title.extend(['val_loss','test_acc','best_acc'])
+ log_write.set_names(title)
+ else:
+ print('Training from scratch.')
+ log_write = Logger(os.path.join(checkpoints, 'log.txt'), title=config['base']['algorithm'])
+ title = list(loss_bin.keys())
+ title.extend(['val_loss','test_acc','best_acc'])
+ log_write.set_names(title)
+ center_flag = False
+ center_model = None
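+    # finetune switches on the center-loss phase: training restarts at a low lr
+    # and a frozen copy of the current model provides the center pseudo-labels.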
+ if config['base']['finetune']:
+ start_epoch = 0
+ optimizer.param_groups[0]['lr'] = 0.0001
+ center_flag = True
+ center_model = copy.deepcopy(model)
+ for epoch in range(start_epoch,config['base']['n_epoch']):
+ model.train()
+ optimizer_decay(config, optimizer, epoch)
+ if config['loss']['use_center']:
+ optimizer_decay_center(config, optimizer_center, epoch)
+ loss_write,loss_flag = ModelTrain(train_data_loader,LabelConverter, model,center_model, criterion, optimizer, center_criterion,optimizer_center,center_flag,loss_bin, config, epoch)
+# if loss_flag < config['loss']['min_score']:
+# center_flag = True
+ if(epoch >= config['base']['start_val']):
+ model.eval()
+ val_acc,val_loss = ModelEval(test_data_loader,LabelConverter, model,criterion ,config)
+ print('val_acc:',val_acc,'val_loss',val_loss)
+ if (val_acc > best_acc):
+ save_checkpoint({
+ 'epoch': epoch + 1,
+ 'state_dict': model.state_dict(),
+ 'lr': config['optimizer']['base_lr'],
+ 'optimizer': optimizer.state_dict(),
+ 'best_acc': val_acc
+ }, checkpoints, config['base']['algorithm'] + '_best' + '.pth.tar')
+ best_acc = val_acc
+
+
+ loss_write.extend([val_loss,val_acc,best_acc])
+ log_write.append(loss_write)
+ for key in loss_bin.keys():
+ loss_bin[key].loss_clear()
+ if epoch%config['base']['save_epoch'] ==0:
+ save_checkpoint({
+ 'epoch': epoch + 1,
+ 'state_dict': model.state_dict(),
+ 'lr': config['optimizer']['base_lr'],
+ 'optimizer': optimizer.state_dict(),
+ 'best_acc': 0
+ },checkpoints,config['base']['algorithm']+'_'+str(epoch)+'.pth.tar')
+
+
+if __name__ == "__main__":
+
+
+ parser = argparse.ArgumentParser(description='Hyperparams')
+ parser.add_argument('--config', help='config file path')
+ parser.add_argument('--log_str', help='log title')
+ args = parser.parse_args()
+
+
+ TrainValProgram(args)
\ No newline at end of file