diff --git a/setup.py b/setup.py index 592661a..8f3cf4c 100644 --- a/setup.py +++ b/setup.py @@ -75,7 +75,6 @@ def parse_requirements(filename): author_email="", maintainer="ajkdrag", maintainer_email="", - python_requires="==3.8.*", install_requires=requirements, extras_require=get_extra_requires("extra-requirements.txt"), keywords=["ocrtoolkit"], @@ -85,8 +84,7 @@ def parse_requirements(filename): classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", ], ) diff --git a/src/ocrtoolkit/utilities/__init__.py b/src/ocrtoolkit/utilities/__init__.py index 257b654..0e7011c 100644 --- a/src/ocrtoolkit/utilities/__init__.py +++ b/src/ocrtoolkit/utilities/__init__.py @@ -1,9 +1,10 @@ +from .det_utils import * from .draw_utils import * from .ds_utils import * +from .eval_utils import * +from .geometry_utils import * from .img_utils import * from .io_utils import * from .misc_utils import * -from .network_utils import * -from .geometry_utils import * -from .det_utils import * from .model_utils import * +from .network_utils import * diff --git a/src/ocrtoolkit/utilities/box_utils.py b/src/ocrtoolkit/utilities/box_utils.py new file mode 100644 index 0000000..ea3c2cf --- /dev/null +++ b/src/ocrtoolkit/utilities/box_utils.py @@ -0,0 +1,62 @@ +from typing import List, Tuple + +import numpy as np + +from ocrtoolkit.utilities.geometry_utils import estimate_page_angle, rotate_boxes + + +def sort_boxes(boxes: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Sort bounding boxes from top to bottom, left to right.""" + if boxes.ndim == 3: # Rotated boxes + angle = -estimate_page_angle(boxes) + boxes = rotate_boxes( + loc_preds=boxes, angle=angle, orig_shape=(1024, 1024), min_angle=5.0 + ) + boxes = np.concatenate((boxes.min(axis=1), boxes.max(axis=1)), axis=-1) + sort_indices = ( + boxes[:, 0] + 2 * boxes[:, 3] / np.median(boxes[:, 3] - boxes[:, 1]) + ).argsort() + return sort_indices, boxes + + +def resolve_sub_lines( + boxes: np.ndarray, word_idcs: List[int], paragraph_break: float +) -> List[List[int]]: + """Split a line in sub-lines.""" + lines = [] + word_idcs = sorted(word_idcs, key=lambda idx: boxes[idx, 0]) + + if len(word_idcs) < 2: + return [word_idcs] + + sub_line = [word_idcs[0]] + for i in word_idcs[1:]: + if boxes[i, 0] - boxes[sub_line[-1], 2] < paragraph_break: + sub_line.append(i) + else: + lines.append(sub_line) + sub_line = [i] + lines.append(sub_line) + return lines + + +def resolve_lines(boxes: np.ndarray, paragraph_break: float) -> List[List[int]]: + """Order boxes to group them in lines.""" + idxs, boxes = sort_boxes(boxes) + y_med = np.median(boxes[:, 3] - boxes[:, 1]) + + lines, words, y_center_sum = [], [idxs[0]], boxes[idxs[0], [1, 3]].mean() + for idx in idxs[1:]: + y_dist = abs(boxes[idx, [1, 3]].mean() - y_center_sum / len(words)) + + if y_dist < y_med / 2: + words.append(idx) + y_center_sum += boxes[idx, [1, 3]].mean() + else: + lines.extend(resolve_sub_lines(boxes, words, paragraph_break)) + words, y_center_sum = [idx], boxes[idx, [1, 3]].mean() + + if words: # Process the last line + lines.extend(resolve_sub_lines(boxes, words, paragraph_break)) + + return lines diff --git a/src/ocrtoolkit/utilities/det_utils.py b/src/ocrtoolkit/utilities/det_utils.py index 9b812a8..9a284eb 100644 --- a/src/ocrtoolkit/utilities/det_utils.py +++ b/src/ocrtoolkit/utilities/det_utils.py @@ -1,63 +1,82 @@ -import numpy as np -from typing import List, Tuple -from ocrtoolkit.utilities.geometry_utils import ( - estimate_page_angle, - rotate_boxes, -) - - -def sort_boxes(boxes: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: - """Sort bounding boxes from top to bottom, left to right.""" - if boxes.ndim == 3: # Rotated boxes - angle = -estimate_page_angle(boxes) - boxes = rotate_boxes( - loc_preds=boxes, angle=angle, orig_shape=(1024, 1024), min_angle=5.0 - ) - boxes = np.concatenate((boxes.min(axis=1), boxes.max(axis=1)), axis=-1) - sort_indices = ( - boxes[:, 0] + 2 * boxes[:, 3] / np.median(boxes[:, 3] - boxes[:, 1]) - ).argsort() - return sort_indices, boxes - - -def resolve_sub_lines( - boxes: np.ndarray, word_idcs: List[int], paragraph_break: float -) -> List[List[int]]: - """Split a line in sub-lines.""" - lines = [] - word_idcs = sorted(word_idcs, key=lambda idx: boxes[idx, 0]) - - if len(word_idcs) < 2: - return [word_idcs] - - sub_line = [word_idcs[0]] - for i in word_idcs[1:]: - if boxes[i, 0] - boxes[sub_line[-1], 2] < paragraph_break: - sub_line.append(i) - else: - lines.append(sub_line) - sub_line = [i] - lines.append(sub_line) - return lines - - -def resolve_lines(boxes: np.ndarray, paragraph_break: float) -> List[List[int]]: - """Order boxes to group them in lines.""" - idxs, boxes = sort_boxes(boxes) - y_med = np.median(boxes[:, 3] - boxes[:, 1]) - - lines, words, y_center_sum = [], [idxs[0]], boxes[idxs[0], [1, 3]].mean() - for idx in idxs[1:]: - y_dist = abs(boxes[idx, [1, 3]].mean() - y_center_sum / len(words)) - - if y_dist < y_med / 2: - words.append(idx) - y_center_sum += boxes[idx, [1, 3]].mean() - else: - lines.extend(resolve_sub_lines(boxes, words, paragraph_break)) - words, y_center_sum = [idx], boxes[idx, [1, 3]].mean() - - if words: # Process the last line - lines.extend(resolve_sub_lines(boxes, words, paragraph_break)) - - return lines +import json +from pathlib import Path + +import h5py +from loguru import logger + + +def save_dets(l_dets, path: str): + with h5py.File(path, "w") as f: + group = f.create_group("dets") + for idx, dets in enumerate(l_dets): + npy_bboxes = dets.to_numpy(encode=True) + dset = group.create_dataset(f"dets_{idx}", data=npy_bboxes) + dset.attrs["width"] = dets.width + dset.attrs["height"] = dets.height + dset.attrs["img_name"] = dets.img_name + logger.info(f"Detections saved to {path}") + + +def save_dets_as_label_studio(l_dets, path: str, subdir_images="images"): + """Save detections as Label Studio json format""" + base_dir = "/data/local-files/?d={subdir_images}" + l_json_data = [ + { + "data": { + "image": base_dir.format( + subdir_images=Path(subdir_images) + .joinpath(detection.img_name) + .as_posix() + ), + }, + "predictions": [ + { + "model_version": "one", + "score": 0.5, + "result": [ + { + "id": f"bbox{i+1}", + "type": "rectanglelabels", + "from_name": "label", + "to_name": "image", + "original_width": detection.width, + "original_height": detection.height, + "image_rotation": 0, + "value": { + "rotation": 0, + "x": bbox.x1 * 100, + "y": bbox.y1 * 100, + "width": bbox.w * 100, + "height": bbox.h * 100, + "rectanglelabels": [bbox.label], + }, + } + for i, bbox in enumerate(detection.normalize().bboxes) + ], + } + ], + } + for detection in l_dets + ] + with open(path, "w") as f: + json.dump(l_json_data, f, indent=2) + + +def load_dets(path: str): + from ocrtoolkit.wrappers.bbox import BBox + from ocrtoolkit.wrappers.detection_results import DetectionResults + + with h5py.File(path, "r") as f: + l_dets = [] + group = f["dets"] + dets_keys = sorted(group.keys(), key=lambda x: int(x.split("_")[-1])) + for key in dets_keys: + dets_width = int(group[key].attrs["width"]) + dets_height = int(group[key].attrs["height"]) + dets_img_name = str(group[key].attrs["img_name"]) + dets_data = group[key][()] + l_bboxes = [BBox.from_numpy(bbox) for bbox in dets_data] + l_dets.append( + DetectionResults(l_bboxes, dets_width, dets_height, dets_img_name) + ) + return l_dets diff --git a/src/ocrtoolkit/utilities/eval_utils.py b/src/ocrtoolkit/utilities/eval_utils.py index 59fcd8d..48ef871 100644 --- a/src/ocrtoolkit/utilities/eval_utils.py +++ b/src/ocrtoolkit/utilities/eval_utils.py @@ -25,13 +25,8 @@ def compare_dataframes( Returns: pd.DataFrame: DataFrame containing comparison results as percentage of matches. + pd.DataFrame: The merged dataframes used for comparison. """ - # Check if indices are named and assign names if not - if df_a.index.name is None: - df_a.index.name = index_a - if df_b.index.name is None: - df_b.index.name = index_b - # Set indices if they are not already set if index_a != df_a.index.name: df_a = df_a.set_index(index_a) @@ -56,4 +51,7 @@ def compare_dataframes( matches = (comparison_results[col_a] == comparison_results[col_b]).mean() * 100 match_percentages[col_a[:-2]] = f"{matches:.2f}%" - return pd.DataFrame(match_percentages, index=["Match Percentage"]) + return ( + pd.DataFrame(match_percentages, index=["Match Percentage"]), + comparison_results, + ) diff --git a/src/ocrtoolkit/utilities/geometry_utils.py b/src/ocrtoolkit/utilities/geometry_utils.py index da441e1..2f8b1cd 100644 --- a/src/ocrtoolkit/utilities/geometry_utils.py +++ b/src/ocrtoolkit/utilities/geometry_utils.py @@ -1,6 +1,7 @@ -import numpy as np from typing import Optional, Tuple +import numpy as np + def estimate_page_angle(polys: np.ndarray) -> float: """Takes a batch of rotated previously diff --git a/src/ocrtoolkit/utilities/network_utils.py b/src/ocrtoolkit/utilities/network_utils.py index e383c87..45c8122 100644 --- a/src/ocrtoolkit/utilities/network_utils.py +++ b/src/ocrtoolkit/utilities/network_utils.py @@ -5,7 +5,7 @@ from typing import Optional from loguru import logger -from tqdm.auto import tqdm +from tqdm.autonotebook import tqdm from ocrtoolkit.utilities.io_utils import extract_files diff --git a/src/ocrtoolkit/wrappers/detection_results.py b/src/ocrtoolkit/wrappers/detection_results.py index a34f316..46068b8 100644 --- a/src/ocrtoolkit/wrappers/detection_results.py +++ b/src/ocrtoolkit/wrappers/detection_results.py @@ -5,9 +5,9 @@ from ocrtoolkit.datasets.base import BaseDS from ocrtoolkit.datasets.imageds import ImageDS +from ocrtoolkit.utilities.box_utils import resolve_lines from ocrtoolkit.utilities.draw_utils import draw_bbox from ocrtoolkit.utilities.misc_utils import get_samples, get_uuid -from ocrtoolkit.utilities.det_utils import resolve_lines from ocrtoolkit.wrappers.bbox import BBox diff --git a/src/ocrtoolkit/wrappers/io.py b/src/ocrtoolkit/wrappers/io.py deleted file mode 100644 index 0af0776..0000000 --- a/src/ocrtoolkit/wrappers/io.py +++ /dev/null @@ -1,82 +0,0 @@ -import json -from pathlib import Path - -import h5py -from loguru import logger - -from ocrtoolkit.wrappers.bbox import BBox -from ocrtoolkit.wrappers.detection_results import DetectionResults - - -def save_dets(l_dets, path: str): - with h5py.File(path, "w") as f: - group = f.create_group("dets") - for idx, dets in enumerate(l_dets): - npy_bboxes = dets.to_numpy(encode=True) - dset = group.create_dataset(f"dets_{idx}", data=npy_bboxes) - dset.attrs["width"] = dets.width - dset.attrs["height"] = dets.height - dset.attrs["img_name"] = dets.img_name - logger.info(f"Detections saved to {path}") - - -def save_dets_as_label_studio(l_dets, path: str, subdir_images="images"): - """Save detections as Label Studio json format""" - base_dir = "/data/local-files/?d={subdir_images}" - l_json_data = [ - { - "data": { - "image": base_dir.format( - subdir_images=Path(subdir_images) - .joinpath(detection.img_name) - .as_posix() - ), - }, - "predictions": [ - { - "model_version": "one", - "score": 0.5, - "result": [ - { - "id": f"bbox{i+1}", - "type": "rectanglelabels", - "from_name": "label", - "to_name": "image", - "original_width": detection.width, - "original_height": detection.height, - "image_rotation": 0, - "value": { - "rotation": 0, - "x": bbox.x1 * 100, - "y": bbox.y1 * 100, - "width": bbox.w * 100, - "height": bbox.h * 100, - "rectanglelabels": [bbox.label], - }, - } - for i, bbox in enumerate(detection.normalize().bboxes) - ], - } - ], - } - for detection in l_dets - ] - with open(path, "w") as f: - json.dump(l_json_data, f, indent=2) - - -def load_dets(path: str): - with h5py.File(path, "r") as f: - l_dets = [] - group = f["dets"] - dets_keys = sorted(group.keys(), key=lambda x: int(x.split("_")[-1])) - for key in dets_keys: - dets_width = int(group[key].attrs["width"]) - dets_height = int(group[key].attrs["height"]) - dets_img_name = str(group[key].attrs["img_name"]) - dets_data = group[key][()] - l_bboxes = [BBox.from_numpy(bbox) for bbox in dets_data] - l_dets.append( - DetectionResults(l_bboxes, dets_width, dets_height, dets_img_name) - ) - return l_dets