CRNN κΈ°λ°μ OCR νκΈ λͺ¨λΈ Training
- CRNN_training.py: ~ λͺ¨λΈ νμ΅ λ° λͺ¨λΈ(κ°μ€μΉ) μ μ₯
- CRNN_train_test.ipynb: ~ λͺ¨λΈ νμ΅ λ° λͺ¨λΈ(κ°μ€μΉ) μ μ₯, loss μΆμ΄ κ·Έλν, λͺ¨λΈ test
CRNN_train_test.ipynb λ₯Ό ν΅ν΄ μ½λλ₯Ό ꡬμ±νμκ³ ,
μ΅μ’
λͺ¨λΈ νμ΅μ CRNN_train_test.ipynb λ₯Ό λ°νμΌλ‘ ν CRNN_training.py λ₯Ό ν°λ―Έλμμ μ€νν΄ λ°±κ·ΈλΌμ΄λμμλ νμ΅μ΄ μ§νλ μ μκ² νμμ΅λλ€.
μ€μ νμ΅μ CRNN_training.py μ½λλ₯Ό μ€ννλ©΄ λκ³ ,
λͺ¨λΈ λ° μ½λ ν
μ€νΈλ CRNN_train_test.ipynb μ½λλ₯Ό μ€ννλ©΄ λ©λλ€.
λ³Έ READMEλ CRNN_train_test.ipynb κΈ°μ€μΌλ‘ μμ±λμμ΅λλ€.
λ€μκ³Ό κ°μ λλ ν 리 κ΅¬μ‘°λ‘ μ΄λ£¨μ΄μ Έ μμ΅λλ€.
OCR_CRNN/
βββ printed/
β βββ 03343000.png
β βββ 03343001.png
β β ...
β βββ 03385349.png
β
βββ utils/
β βββ bboxes.py
β βββ losses.py
β βββ model.py
β βββ training.py
β
βββ CRNN_train_test.ipynb
βββ CRNN_training.py
βββ CRNN_model_2_v1.h5
βββ CRNN_model_2_v2.h5
βββ CRNN_weights_2_v1.h5
βββ CRNN_weights_2_v2.h5
βββ crnn_data.py
βββ crnn_model.py
βββ crnn_utils.py
βββ ssd_data.py
βββ korean_printed_sentence.json
βββ NanumBarunGothic.ttf
βββ requirements.txt
$ pip install -r requirements.txt
μ 컀맨λλ₯Ό μ€νν΄ λͺ¨λΈ νμ΅μ νμν ν¨ν€μ§λ₯Ό μ€μΉν©λλ€. CRNN_train_test.ipynbμ ν΄λΉ μ½λκ° ν¬ν¨λμ΄ μμ§ μκΈ° λλ¬Έμ λ°λ‘ μ€νν΄μ£Όμ΄μΌ ν©λλ€.
https://aihub.or.kr/aihubdata/data/view.do?currMenu=115&topMenu=100
νμ΅ λ°μ΄ν°λ AI Hubμ 'νκ΅μ΄ κΈμ체 μ΄λ―Έμ§' λ°μ΄ν°μ
μ μ΄μ©νμ΅λλ€.
ν΄λΉ λ°μ΄ν°μ
μ μκΈμ¨, μΈμ체, μ€μ¬ λ°μ΄ν° λ‘ κ΅¬μ±λμ΄ μκ³ , μ΄ μ€ μΈμ체 λ°μ΄ν°λ κΈμ, λ¨μ΄, λ¬Έμ₯ λ°μ΄ν°λ‘ ꡬλΆλμ΄ μμ΅λλ€.
νλ‘μ νΈμμ OCR λͺ¨λΈμ inputμΌλ‘ λ€μ΄κ° λ°μ΄ν°κ° κ³μ½μμ΄κΈ° λλ¬Έμ,
νμ¬κΉμ§ νμ΅μμ μΈμ체 λ°μ΄ν° μ€ λ¬Έμ₯ λ°μ΄ν° 40,304κ°λ₯Ό μ¬μ©ν΄ νμ΅μ μ§ννμ΅λλ€.
μΆν νμ΅μ μΈμ체_λ¬Έμ₯ λ°μ΄ν°μ λν Data Augmentation μ΄λ μκΈμ¨ λ°μ΄ν°, κ·Έλ¦¬κ³ κ³μ½μ μ΄λ―Έμ§λ₯Ό λ°νμΌλ‘ μ μ²λ¦¬λ₯Ό μ§νν custom data λ₯Ό μ΄μ©ν΄ fine tuning λ° transfer learningμ μ§νν΄ λͺ¨λΈμ lossλ₯Ό λ μ€μΌ μκ°μ
λλ€.
미리 μ
λ‘λ ν΄λμ μ΄λ―Έμ§ νμΌλ€κ³Ό JSON νμΌμ μ΄μ©ν΄ GTUtility κ°μ²΄λ₯Ό μμ±ν©λλ€.
μ΄λ―Έμ§ κ°μκ° λ§μ νμ ν΄λΉ μ
μ μ½ 3~4λΆμ λ μμλ©λλ€.
μμ±λ GTUtility κ°μ²΄λ₯Ό μ΄μ©ν΄ target value(λΌλ²¨)μ μμ±ν©λλ€.
ν΄λΉ κ³Όμ μ κ°λ΅νκ² μ€λͺ
νλ©΄ λ€μκ³Ό κ°μ΅λλ€.
3) 리μ€νΈ μμ λ¬Έμμ΄μ λ¬Έμ λ¨μλ‘ μλΌμ£Όκ³ , λμ λ리μ λ£μ΄ μ€λ³΅μ μ κ±°ν©λλ€.
5) 리μ€νΈλ₯Ό λ¬Έμμ΄ ννλ‘ λ°κΏμ£Όκ³ , κ³μ½μμ μμ£Ό μ¬μ©λλ 곡백, μ«μ, .,:()[]<>"'_ λ±μ κΈ°νΈλ€μ λ¬Έμμ΄μ μΆκ°μμΌμ€λλ€.
gt_util_train, gt_util_val = gt_util.split(0.8)
Train : Validation = 8 : 2 μ λΉμ¨λ‘ λλ μ€λλ€.
λΉμ¨μ λ³κ²½νκ³ μΆμΌλ©΄ gt_util.split()
ν¨μ μμ νλΌλ―Έν°λ₯Ό μνλ train set λΉμ¨λ‘ μ€μ ν΄μ£Όλ©΄ λ©λλ€.
input_width = 256
input_height = 32
batch_size = 128
input_shape = (input_width, input_height, 1)
input μ΄λ―Έμ§μ width
μ height
, κ·Έλ¦¬κ³ batch size
λ₯Ό μ€μ ν©λλ€.
λ³Έ νμ΅μμμ inputμ λ¬Έμ₯ λ°μ΄ν°μκΈ° λλ¬Έμ widthλ₯Ό heightλ³΄λ€ ν¬κ² μ€μ ν΄μ£Όμμ΅λλ€.
batch sizeλ λ³ΈμΈμ νμ΅ νκ²½μ΄λ λͺ¨λΈ μ±λ₯μ λ°λΌ λ³κ²½ν΄μ£Όλ©΄ λ©λλ€
freeze = ['conv1_1',
'conv2_1',
'conv3_1', 'conv3_2',
#'conv4_1',
#'conv5_1',
#'conv6_1',
#'lstm1',
#'lstm2'
]
fine tuningμ μν΄ λκ²°ν LayerμΈ΅μ μ€μ ν΄μ€λλ€. μ΄ λν λͺ¨λΈ μ±λ₯μ λ§κ² μ‘°μ ν΄μ£Όλ©΄ λ©λλ€.
model, model_pred = CRNN(input_shape, len(korean_dict))
experiment = 'crnn_korean_test'
μ€μ νμ΅μμ versionλͺ μ 'crnn_korean_v1', 'crnn_korean_v2' λ±μΌλ‘ λ°κΏκ°λ©° μ€μ ν΄μ£Όμμ΅λλ€.
max_string_len = model_pred.output_shape[1]
gen_train = InputGenerator(gt_util_train, batch_size, korean_dict, input_shape[:2],
grayscale=True, max_string_len=max_string_len, concatenate=False)
gen_val = InputGenerator(gt_util_val, batch_size, korean_dict, input_shape[:2],
grayscale=True, max_string_len=max_string_len, concatenate=False)
model.load_weights('./CRNN_weights_2_v2.h5')
μ΄μ μ μ§ννλ νμ΅μ κ°μ€μΉλ₯Ό loadν΄ transfer learningμ μ§νν©λλ€.
νμ΅ μ€μ μ μ₯ν κ°μ€μΉλ₯Ό λΆλ¬μλ λκ³ , λ°λ‘ μ μ₯ν κ°μ€μΉλ₯Ό λΆλ¬μλ λ©λλ€.
λ³Έ μ½λμμ λ°λ‘ μ μ₯ν κ°μ€μΉλ₯Ό loadν΄μμ΅λλ€.
checkdir = './checkpoints/' + time.strftime('%Y%m%d%H%M') + '_' + experiment
if not os.path.exists(checkdir):
os.makedirs(checkdir)
with open(checkdir+'/source.py','wb') as f:
source = ''.join(['# In[%i]\n%s\n\n' % (i, In[i]) for i in range(len(In))])
f.write(source.encode())
μμμ μ€μ ν λͺ¨λΈ versionλͺ
μ ν λλ‘ directoryλ₯Ό μμ±ν΄ νμ΅ κ³Όμ μ μ μ₯ν©λλ€.
λ§μ½ CRNN_train_test.ipynb κ° μλ CRNN_training.py λ‘ νμ΅μ μ§ννλ€λ©΄,
νλ¨μ μ½λ λΈλμ μμ ν΄μΌ ν©λλ€.
optimizer = SGD(learning_rate=0.0001, decay=1e-6, momentum=0.9, nesterov=True, clipnorm=5)
Optimizerλ₯Ό μ€μ ν©λλ€.
λ³Έ μ½λμμ optimizerλ‘ SGD
λ₯Ό μ¬μ©νμΌλ, νμμ λ°λΌ Adam
κ³Ό κ°μ λ€λ₯Έ λͺ¨λΈλ μ¬μ©ν μ μμ΅λλ€.
λ§μΌ λ€λ₯Έ λͺ¨λΈμ μ¬μ©ν κ²½μ°, λ°λ‘ μ½λλ₯Ό ꡬννκ±°λ λΌμ΄λΈλ¬λ¦¬λ₯Ό λ‘λν΄μΌ ν©λλ€.
learning rate
λ 0.001λΆν° 0.0001κΉμ§ κ°μ λ³κ²½ν΄κ°λ©΄μ νμ΅μ μ§ννμ΅λλ€.
λ³ΈμΈμ μν©μ λ§κ² κ°μ λ³κ²½νλ©΄μ μ¬μ©νλ©΄ λ©λλ€.
λ³Έ νμ΅μμ λ°λ‘ κ°μ λ³κ²½νμ§ μμμΌλ, νμν κ²½μ° decay
λ momentum
μ κ°μ λ³κ²½ν μλ μμ΅λλ€.
for layer in model.layers:
layer.trainable = not layer.name in freeze
model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=optimizer)
loss λͺ¨λΈλ‘λ ctc loss
λ₯Ό μ¬μ©νμμΌλ, μ΄ λν λ³κ²½ κ°λ₯ν©λλ€.
from keras.callbacks import ModelCheckpoint, EarlyStopping
hist = model.fit(gen_train.generate(),
steps_per_epoch=gt_util_train.num_objects // batch_size,
epochs=1000,
validation_data=gen_val.generate(),
validation_steps=gt_util_val.num_objects // batch_size,
callbacks=[
ModelCheckpoint(checkdir+'/weights.{epoch:03d}.h5', verbose=1, save_weights_only=True),
#ModelSnapshot(checkdir, 100),
Logger(checkdir),
EarlyStopping(monitor='val_loss', mode='auto', restore_best_weights=True, verbose=1, patience=20)
],
initial_epoch=0)
λͺ¨λΈ νμ΅μ μ§νν©λλ€.
epochs
κ°μ λ³κ²½ν μ μμ΅λλ€.
callback ν¨μλ€ λν μ€μ ν΄μ£Όμλλ°,
ModelCheckpoint
ν¨μλ₯Ό ν΅ν΄ νλμ epochκ° λλ λλ§λ€ ν΄λΉ epochμ λͺ¨λΈμ κ°μ€μΉλ₯Ό μ μ₯ν΄μ£Όμκ³ ,
EarlyStopping
ν¨μλ₯Ό ν΅ν΄ 20 epochs λμ validation lossκ° κ°μνμ§ μλλ€λ©΄ λμ΄μ νμ΅μ μ§νν νμκ° μλ€κ³ νλ¨ν΄ νμ΅μ μ€λ¨νλλ‘ νμ΅λλ€.
ν΄λΉ μ ꡬλ μ,
'[ WARN:6@537.712] global /io/opencv/modules/imgcodecs/src/loadsave.cpp (239) findDecoder imread_('./printed/03384889.png'): can't open/read file: check file path/integrity'
μ κ°μ warning meassageκ° λ¨λλ°,
μ΄λ JSON νμΌμμ λ¬Έμ₯ λ°μ΄ν°κ° μλ λ€λ₯Έ μ΄λ―Έμ§ λ°μ΄ν°λ€μ λν μ λ³΄κ° μ κ±°λμ§ μμκΈ° λλ¬Έμ λ¨λ λ©μΈμ§λ‘, 무μνλ©΄ λ©λλ€.
JSON νμΌμ λν μμ μ μ§ννμΌλ μλ²½νκ² μ 리λμ§ μμ μΆν μ΄ λΆλΆμ 보μν μκ°μ λλ€.
loss = hist.history['loss']
val_loss = hist.history['val_loss']
epochs = range(len(loss))
plt.figure(figsize=(15,10))
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'b',label='Validation loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()
ν΄λΉ μ½λλ₯Ό ν΅ν΄ epochμ λ°λ₯Έ lossκ°μ μΆμ΄λ₯Ό νμΈν μ μλ€.
## (1) Model μ μ₯ ```python model.save('CRNN_model_test.h5') ``` νμ΅ν model μ체λ₯Ό μ μ₯ν©λλ€. νλΌλ―Έν°μλ λͺ¨λΈμ΄ μ μ₯λ κ²½λ‘μ νμΌλͺ μ μ€μ ν΄μ£Όλ©΄ λ©λλ€.
## (2) Weight μ μ₯ ```python model.save_weights('CRNN_weights_test.h5') ``` νμ΅ν modelμ weight(κ°μ€μΉ)λ₯Ό μ μ₯ν©λλ€. μ΄ λν νλΌλ―Έν°μλ κ°μ€μΉκ° μ μ₯λ κ²½λ‘μ νμΌλͺ μ μ€μ ν΄μ£Όλ©΄ λ©λλ€.
import matplotlib as mpl
# μ λμ½λ κΉ¨μ§νμ ν΄κ²°
mpl.rcParams['axes.unicode_minus'] = False
# λλκ³ λ ν°νΈ μ μ©
plt.rcParams["font.family"] = 'NanumGothic'
model μ μ© κ²°κ³Όκ° νκΈμ΄κΈ° λλ¬Έμ μ λμ½λ κΉ¨μ§ νμμ ν΄κ²°νμκ³ ,
νμ΅ νκ²½μ λ°λΌ ν°νΈκ° κΉ¨μ Έμ λμ€κΈ°λ ν΄ λ°λ‘ ν°νΈλ₯Ό μ μ©ν΄μ£Όμλ€.
g = gen_val.generate()
d = next(g)
res = model_pred.predict(d[0]['image_input'])
mean_ed = 0
mean_ed_norm = 0
plot_name = 'crnn_korean'
for i in range(32):
chars = [alphabet[c] for c in np.argmax(res[i], axis=1)]
gt_str = d[0]['source_str'][i]
res_str = decode(chars)
ed = editdistance.eval(gt_str, res_str)
ed_norm = ed / len(gt_str)
mean_ed += ed
mean_ed_norm += ed_norm
img = d[0]['image_input'][i][:,:,0].T
plt.figure(figsize=[10,1.03])
plt.imshow(img, cmap='gray', interpolation=None)
ax = plt.gca()
#plt.text(0, 45, '%s' % (''.join(chars)) )
plt.text(0, 60, 'GT: %-24s RT: %-24s %0.2f' % (gt_str, res_str, ed_norm))
plt.show()
λ³Έ μ½λμμ νμ΅μ μ¬μ©νμ§ μμ validation setμμ μ΄λ―Έμ§ λ°μ΄ν°λ₯Ό κ°μ Έμ testλ₯Ό μ§ννμ΅λλ€.
μΆν validation setμ μλ custom data λν μ μ©ν΄ κ²°κ³Όλ₯Ό νμΈν μ μλλ‘ μ½λλ₯Ό μμ±ν κ³νμ
λλ€.
νμ΅ νμλ λλ ν λ¦¬κ° λ€μκ³Ό κ°μ΄ νμμΌλ‘ λ³κ²½λ©λλ€.
OCR_CRNN/
βββ checkpoints/
β βββ 202211302056_crnn_korean_v1
β β βββ history.csv
β β βββ log.csv
β β βββ weights.001.h5
β β βββ weights.002.h5
β β βββ ...
β β βββ weights.227.h5
β βββ 202212011003_crnn_korean_v2
β β βββ history.csv
β β βββ log.csv
β β βββ weights.001.h5
β β βββ weights.002.h5
β β βββ ...
β β βββ weights.037.h5
β βββ ...
β βββ 202212061250_crnn_korean_2_v1
β βββ history.csv
β βββ log.csv
β βββ weights.001.h5
β βββ weights.002.h5
β βββ ...
β βββ weights.524.h5
β
βββ printed/
β βββ 03343000.png
β βββ 03343001.png
β β ...
β βββ 03385349.png
β
βββ utils/
β βββ bboxes.py
β βββ losses.py
β βββ model.py
β βββ training.py
β
βββ CRNN_train_test.ipynb
βββ CRNN_training.py
βββ CRNN_model_2_v1.h5
βββ CRNN_model_2_v2.h5
βββ CRNN_model_test.h5
βββ CRNN_weights_2_v1.h5
βββ CRNN_weights_2_v2.h5
βββ CRNN_weights_test.h5
βββ crnn_data.py
βββ crnn_model.py
βββ crnn_utils.py
βββ ssd_data.py
βββ korean_printed_sentence.json
βββ NanumBarunGothic.ttf
βββ requirements.txt