Skip to content

Commit

Permalink
feat(rapidocr_api): update version to v0.1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed Oct 28, 2024
1 parent 3aa4463 commit 4ddd44f
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 35 deletions.
20 changes: 11 additions & 9 deletions api/README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
### See [Documentation](https://rapidai.github.io/RapidOCRDocs/install_usage/rapidocr_api/usage/)


### API 修改说明

* uvicorn启动时,reload参数设置为False,避免反复加载;
* 增加了启动参数: workers,可启动多个实例,以满足多并发需求。
* 增加了启动参数: workers,可启动多个实例,以满足多并发需求。
* 可通过环境变量传递模型参数:det_model_path, cls_model_path, rec_model_path;
* 接口中可传入参数,控制是否使用检测、方向分类和识别这三部分的模型;客户端调用见`demo.py`
* 增加了Dockerfile,可自行构建镜像。
Expand All @@ -13,7 +12,7 @@

Windows下启动:

```for win shell
```shell
set det_model_path=I:\models\图像相关\OCR\RapidOCR\PP-OCRv4\ch_PP-OCRv4_det_server_infer.onnx
set det_model_path=

Expand All @@ -22,6 +21,7 @@ rapidocr_api
```

Linux下启动:

```shell
# 默认参数启动
rapidocr_api
Expand All @@ -30,25 +30,27 @@ rapidocr_api
rapidocr_api -ip 0.0.0.0 -p 9005 -workers 2

# 指定模型
expert det_model_path=/mnt/sda1/models/PP-OCRv4/ch_PP-OCRv4_det_server_infer.onnx
expert det_model_path=/mnt/sda1/models/PP-OCRv4/ch_PP-OCRv4_det_server_infer.onnx
expert rec_model_path=/mnt/sda1/models/PP-OCRv4/ch_PP-OCRv4_rec_server_infer.onnx
rapidocr_api -ip 0.0.0.0 -p 9005 -workers 2
```


客户端调用说明:
```

```bash
cd api
python demo.py
```

构建镜像:
```

```bash
cd api
sudo docker build -t="rapidocr_api:0.1.1" .
```

启动镜像:

```
```bash
docker run -p 9003:9003 --name rapidocr_api1 --restart always -d rapidocr_api:0.1.1
```
```
31 changes: 16 additions & 15 deletions api/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@

# print(response.text)

import time
import base64
import time

import requests

url = "http://localhost:9003/ocr"
Expand All @@ -26,36 +27,36 @@
img_str = base64.b64encode(fa.read())

payload = {"image_data": img_str}
response = requests.post(url, data=payload) #, timeout=60
response = requests.post(url, data=payload) # , timeout=60

print(response.json())
etime = time.time() - stime
print(f'用时:{etime:.3f}')
print(f"用时:{etime:.3f}")

print('-'*40)
print("-" * 40)

# 方式二:使用文件上传方式
stime = time.time()
with open(img_path, 'rb') as f:
file_dict = {'image_file': (img_path, f, 'image/png')}
response = requests.post(url, files=file_dict) #, timeout=60
with open(img_path, "rb") as f:
file_dict = {"image_file": (img_path, f, "image/png")}
response = requests.post(url, files=file_dict) # , timeout=60
print(response.json())

etime = time.time() - stime
print(f'用时:{etime:.3f}')
print('-'*40)
print(f"用时:{etime:.3f}")
print("-" * 40)

# 方式三:控制是否使用检测、方向分类和识别这三部分的模型; 不使用检测模型:use_det=False
stime = time.time()
img_path = "../python/tests/test_files/test_without_det.jpg"

with open(img_path, 'rb') as f:
file_dict = {'image_file': (img_path, f, 'image/png')}
with open(img_path, "rb") as f:
file_dict = {"image_file": (img_path, f, "image/png")}
# 添加控制参数
data = {"use_det":False, "use_cls":True, "use_rec":True}
response = requests.post(url, files=file_dict, data=data) #, timeout=60
data = {"use_det": False, "use_cls": True, "use_rec": True}
response = requests.post(url, files=file_dict, data=data) # , timeout=60
print(response.json())

etime = time.time() - stime
print(f'用时:{etime:.3f}')
print('-'*40)
print(f"用时:{etime:.3f}")
print("-" * 40)
48 changes: 37 additions & 11 deletions api/rapidocr_api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,19 @@ def __init__(self) -> None:
cls_model_path = os.getenv("cls_model_path", None)
rec_model_path = os.getenv("rec_model_path", None)

self.ocr = RapidOCR(det_model_path=det_model_path, cls_model_path=cls_model_path, rec_model_path=rec_model_path)
self.ocr = RapidOCR(
det_model_path=det_model_path,
cls_model_path=cls_model_path,
rec_model_path=rec_model_path,
)

def __call__(self, img: Image.Image, use_det=None, use_cls=None, use_rec=None) -> Dict:
def __call__(
self, img: Image.Image, use_det=None, use_cls=None, use_rec=None, **kwargs
) -> Dict:
img = np.array(img)
ocr_res, _ = self.ocr(img, use_det=use_det, use_cls=use_cls, use_rec=use_rec)
ocr_res, _ = self.ocr(
img, use_det=use_det, use_cls=use_cls, use_rec=use_rec, **kwargs
)

if not ocr_res:
return {}
Expand All @@ -40,11 +48,13 @@ def __call__(self, img: Image.Image, use_det=None, use_cls=None, use_rec=None) -
for i, dats in enumerate(ocr_res):
values = {}
for dat in dats:
if type(dat) == str:
if isinstance(dat, str):
values["rec_txt"] = dat
if type(dat) == np.float64:

if isinstance(dat, np.float64):
values["score"] = f"{dat:.4f}"
if type(dat) == list:

if isinstance(dat, list):
values["dt_boxes"] = dat
out_dict[str(i)] = values

Expand All @@ -54,14 +64,20 @@ def __call__(self, img: Image.Image, use_det=None, use_cls=None, use_rec=None) -
app = FastAPI()
processor = OCRAPIUtils()


@app.get("/")
async def root():
return {"message": "Welcome to RapidOCR API Server!"}

@app.post("/ocr")
async def ocr(image_file: UploadFile = None, image_data: str = Form(None),
use_det: bool = Form(None), use_cls: bool = Form(None), use_rec: bool = Form(None)):

@app.post("/ocr")
async def ocr(
image_file: UploadFile = None,
image_data: str = Form(None),
use_det: bool = Form(None),
use_cls: bool = Form(None),
use_rec: bool = Form(None),
):
if image_file:
img = Image.open(image_file.file)
elif image_data:
Expand All @@ -75,14 +91,24 @@ async def ocr(image_file: UploadFile = None, image_data: str = Form(None),
ocr_res = processor(img, use_det=use_det, use_cls=use_cls, use_rec=use_rec)
return ocr_res


def main():
parser = argparse.ArgumentParser("rapidocr_api")
parser.add_argument("-ip", "--ip", type=str, default="0.0.0.0", help="IP Address")
parser.add_argument("-p", "--port", type=int, default=9003, help="IP port")
parser.add_argument('-workers', "--workers", type=int, default=1, help='number of worker process')
parser.add_argument(
"-workers", "--workers", type=int, default=1, help="number of worker process"
)
args = parser.parse_args()

uvicorn.run("rapidocr_api.main:app", host=args.ip, port=args.port, reload=0, workers=args.workers)
uvicorn.run(
"rapidocr_api.main:app",
host=args.ip,
port=args.port,
reload=0,
workers=args.workers,
)


if __name__ == "__main__":
main()

0 comments on commit 4ddd44f

Please sign in to comment.