1
1
# -*- encoding: utf-8 -*-
2
2
# @Author: SWHL
3
3
4
+
4
5
import argparse
5
6
import base64
6
7
import io
8
+ import os
7
9
import sys
8
10
from pathlib import Path
9
11
from typing import Dict
19
21
20
22
class OCRAPIUtils :
21
23
def __init__ (self ) -> None :
22
- self .ocr = RapidOCR ()
24
+ # 从环境变量中读取参数
25
+ det_model_path = os .getenv ("det_model_path" , None )
26
+ cls_model_path = os .getenv ("cls_model_path" , None )
27
+ rec_model_path = os .getenv ("rec_model_path" , None )
28
+
29
+ self .ocr = RapidOCR (det_model_path = det_model_path , cls_model_path = cls_model_path , rec_model_path = rec_model_path )
23
30
24
- def __call__ (self , img : Image .Image ) -> Dict :
31
+ def __call__ (self , img : Image .Image , use_det = None , use_cls = None , use_rec = None ) -> Dict :
25
32
img = np .array (img )
26
- ocr_res , _ = self .ocr (img )
33
+ ocr_res , _ = self .ocr (img , use_det = use_det , use_cls = use_cls , use_rec = use_rec )
27
34
28
35
if not ocr_res :
29
36
return {}
30
37
31
- out_dict = {
32
- str (i ): {
33
- "rec_txt" : rec ,
34
- "dt_boxes" : dt_box ,
35
- "score" : f"{ score :.4f} " ,
36
- }
37
- for i , (dt_box , rec , score ) in enumerate (ocr_res )
38
- }
38
+ # 转换为字典格式: 兼容所有参数情况
39
+ out_dict = {}
40
+ for i , dats in enumerate (ocr_res ):
41
+ values = {}
42
+ for dat in dats :
43
+ if type (dat ) == str :
44
+ values ["rec_txt" ] = dat
45
+ if type (dat ) == np .float64 :
46
+ values ["score" ] = f"{ dat :.4f} "
47
+ if type (dat ) == list :
48
+ values ["dt_boxes" ] = dat
49
+ out_dict [str (i )] = values
50
+
39
51
return out_dict
40
52
41
53
42
54
app = FastAPI ()
43
55
processor = OCRAPIUtils ()
44
56
45
-
46
57
@app .get ("/" )
47
58
async def root ():
48
59
return {"message" : "Welcome to RapidOCR API Server!" }
49
60
50
-
51
61
@app .post ("/ocr" )
52
- async def ocr (image_file : UploadFile = None , image_data : str = Form (None )):
62
+ async def ocr (image_file : UploadFile = None , image_data : str = Form (None ),
63
+ use_det : bool = Form (None ), use_cls : bool = Form (None ), use_rec : bool = Form (None )):
64
+
53
65
if image_file :
54
66
img = Image .open (image_file .file )
55
67
elif image_data :
@@ -60,19 +72,17 @@ async def ocr(image_file: UploadFile = None, image_data: str = Form(None)):
60
72
raise ValueError (
61
73
"When sending a post request, data or files must have a value."
62
74
)
63
-
64
- ocr_res = processor (img )
75
+ ocr_res = processor (img , use_det = use_det , use_cls = use_cls , use_rec = use_rec )
65
76
return ocr_res
66
77
67
-
68
78
def main ():
69
79
parser = argparse .ArgumentParser ("rapidocr_api" )
70
80
parser .add_argument ("-ip" , "--ip" , type = str , default = "0.0.0.0" , help = "IP Address" )
71
81
parser .add_argument ("-p" , "--port" , type = int , default = 9003 , help = "IP port" )
82
+ parser .add_argument ('-workers' , "--workers" , type = int , default = 1 , help = 'number of worker process' )
72
83
args = parser .parse_args ()
73
84
74
- uvicorn .run ("rapidocr_api.main:app" , host = args .ip , port = args .port , reload = True )
75
-
85
+ uvicorn .run ("rapidocr_api.main:app" , host = args .ip , port = args .port , reload = 0 , workers = args .workers )
76
86
77
87
if __name__ == "__main__" :
78
88
main ()
0 commit comments