Skip to content

Commit

Permalink
feat(rapidocr_api): support mutli rapidocr
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed Dec 4, 2024
1 parent 43efc5d commit ccbb8aa
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
22 changes: 11 additions & 11 deletions api/rapidocr_api/main.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,48 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]

import argparse
import base64
import importlib.util
import io
import os
import sys
from pathlib import Path
from typing import Dict
import importlib.util

import numpy as np
import uvicorn
from fastapi import FastAPI, Form, UploadFile
from PIL import Image

if importlib.util.find_spec("rapidocr_runtime"):
if importlib.util.find_spec("rapidocr_onnxruntime"):
from rapidocr_onnxruntime import RapidOCR
elif importlib.util.find_spec("rapidocr_paddle"):
from rapidocr_paddle import RapidOCR
elif importlib.util.find_spec("rapidocr_openvino"):
from rapidocr_openvino import RapidOCR
else:
raise ImportError(
"Pleas install one of [rapidocr-runtime,rapidocr-paddle,rapidocr-openvino]"
"Please install one of [rapidocr_onnxruntime,rapidocr-paddle,rapidocr-openvino]"
)

sys.path.append(str(Path(__file__).resolve().parent.parent))


class OCRAPIUtils:
def __init__(self) -> None:
# 从环境变量中读取参数
det_model_path = os.getenv("det_model_path", None)
cls_model_path = os.getenv("cls_model_path", None)
rec_model_path = os.getenv("rec_model_path", None)

self.ocr = RapidOCR(
det_model_path=det_model_path,
cls_model_path=cls_model_path,
rec_model_path=rec_model_path,
)
if det_model_path is None or cls_model_path is None or rec_model_path is None:
self.ocr = RapidOCR()
else:
self.ocr = RapidOCR(
det_model_path=det_model_path,
cls_model_path=cls_model_path,
rec_model_path=rec_model_path,
)

def __call__(
self, img: Image.Image, use_det=None, use_cls=None, use_rec=None, **kwargs
Expand All @@ -54,7 +55,6 @@ def __call__(
if not ocr_res:
return {}

# 转换为字典格式: 兼容所有参数情况
out_dict = {}
for i, dats in enumerate(ocr_res):
values = {}
Expand Down
12 changes: 6 additions & 6 deletions api/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def get_readme():
latest_version = obtainer(MODULE_NAME)
except ValueError:
latest_version = "0.0.1"
VERSION_NUM = obtainer.version_add_one(latest_version)
VERSION_NUM = obtainer.version_add_one(latest_version, add_patch=True)

if len(sys.argv) > 2:
match_str = " ".join(sys.argv[2:])
Expand All @@ -56,11 +56,6 @@ def get_readme():
license="Apache-2.0",
include_package_data=True,
install_requires=read_txt("requirements.txt"),
extras_require={
'onnx': ['rapidocr-onnxruntime'],
'paddle': ['rapidocr-paddle'],
'openvino': ['rapidocr-openvino'],
},
packages=[MODULE_NAME],
package_data={"": ["*.ico", "*.css", "*.js", "*.html"]},
keywords=[
Expand All @@ -81,4 +76,9 @@ def get_readme():
f"{MODULE_NAME}={MODULE_NAME}.main:main",
],
},
extras_require={
"onnx": ["rapidocr-onnxruntime"],
"paddle": ["rapidocr-paddle"],
"openvino": ["rapidocr-openvino"],
},
)

0 comments on commit ccbb8aa

Please sign in to comment.