Skip to content

Commit

Permalink
Minor refactoring of utilities
Browse files Browse the repository at this point in the history
  • Loading branch information
ajkdrag committed Mar 12, 2024
1 parent 0e2a052 commit c1d8f26
Show file tree
Hide file tree
Showing 9 changed files with 158 additions and 161 deletions.
4 changes: 1 addition & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def parse_requirements(filename):
author_email="",
maintainer="ajkdrag",
maintainer_email="",
python_requires="==3.8.*",
install_requires=requirements,
extras_require=get_extra_requires("extra-requirements.txt"),
keywords=["ocrtoolkit"],
Expand All @@ -85,8 +84,7 @@ def parse_requirements(filename):
classifiers=[
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
],
)
7 changes: 4 additions & 3 deletions src/ocrtoolkit/utilities/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .det_utils import *
from .draw_utils import *
from .ds_utils import *
from .eval_utils import *
from .geometry_utils import *
from .img_utils import *
from .io_utils import *
from .misc_utils import *
from .network_utils import *
from .geometry_utils import *
from .det_utils import *
from .model_utils import *
from .network_utils import *
62 changes: 62 additions & 0 deletions src/ocrtoolkit/utilities/box_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from typing import List, Tuple

import numpy as np

from ocrtoolkit.utilities.geometry_utils import estimate_page_angle, rotate_boxes


def sort_boxes(boxes: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Compute a reading-order permutation for a set of boxes.

    Args:
        boxes: a 2-D array whose columns 0..3 are used as box coordinates
            (presumably x_min, y_min, x_max, y_max -- confirm against the
            callers), or a 3-D array of polygons, which is first de-rotated
            and collapsed to its per-polygon min/max extents.

    Returns:
        A tuple ``(order, boxes)``: ``order`` is the argsort of the sorting
        key, and ``boxes`` is the 2-D array the key was computed on (the
        input itself when it was already 2-D).
    """
    is_rotated = boxes.ndim == 3
    if is_rotated:
        page_angle = estimate_page_angle(boxes)
        straightened = rotate_boxes(
            loc_preds=boxes,
            angle=-page_angle,
            orig_shape=(1024, 1024),
            min_angle=5.0,
        )
        # Collapse every polygon to the min/max of its points.
        lower = straightened.min(axis=1)
        upper = straightened.max(axis=1)
        boxes = np.concatenate((lower, upper), axis=-1)

    # Key = column 0 plus twice column 3 expressed in median box heights.
    median_height = np.median(boxes[:, 3] - boxes[:, 1])
    sort_key = boxes[:, 0] + 2 * boxes[:, 3] / median_height
    return sort_key.argsort(), boxes


def resolve_sub_lines(
    boxes: np.ndarray, word_idcs: List[int], paragraph_break: float
) -> List[List[int]]:
    """Split the words of one line into sub-lines at wide horizontal gaps.

    Args:
        boxes: 2-D array of boxes; column 0 and column 2 are read as the
            start and end of each box along the line.
        word_idcs: row indices into ``boxes`` for the words of the line.
        paragraph_break: gap threshold; a word whose column-0 value minus
            the previous word's column-2 value is not below this threshold
            starts a new sub-line.

    Returns:
        A list of sub-lines, each a list of indices ordered by column 0.
        When fewer than two indices are given, a single sub-line holding
        them (possibly empty) is returned.
    """
    ordered = sorted(word_idcs, key=lambda idx: boxes[idx, 0])
    if len(ordered) < 2:
        return [ordered]

    groups = [[ordered[0]]]
    for current in ordered[1:]:
        previous = groups[-1][-1]
        gap = boxes[current, 0] - boxes[previous, 2]
        if gap < paragraph_break:
            groups[-1].append(current)
        else:
            groups.append([current])
    return groups


def resolve_lines(boxes: np.ndarray, paragraph_break: float) -> List[List[int]]:
    """Group boxes into text lines, then split each line into sub-lines.

    Boxes are visited in the order produced by ``sort_boxes``. A box joins
    the current line while its vertical centre (mean of columns 1 and 3)
    lies within half the median box height of the running mean centre of
    that line; otherwise the current line is closed and a new one starts.
    Every closed line is passed through ``resolve_sub_lines``.

    Args:
        boxes: 2-D array of boxes or 3-D array of polygons, as accepted by
            ``sort_boxes``.
        paragraph_break: horizontal gap threshold forwarded to
            ``resolve_sub_lines``.

    Returns:
        A list of lines, each a list of indices into ``boxes``. An empty
        input yields an empty list.
    """
    # Without this guard an empty input fails on ``idxs[0]`` below.
    if len(boxes) == 0:
        return []

    idxs, boxes = sort_boxes(boxes)
    y_med = np.median(boxes[:, 3] - boxes[:, 1])

    lines: List[List[int]] = []
    words = [idxs[0]]
    y_center_sum = boxes[idxs[0], [1, 3]].mean()
    for idx in idxs[1:]:
        y_center = boxes[idx, [1, 3]].mean()
        # Compare against the mean centre of the words gathered so far.
        y_dist = abs(y_center - y_center_sum / len(words))

        if y_dist < y_med / 2:
            words.append(idx)
            y_center_sum += y_center
        else:
            lines.extend(resolve_sub_lines(boxes, words, paragraph_break))
            words, y_center_sum = [idx], y_center

    # ``words`` always holds at least one index here: flush the last line.
    lines.extend(resolve_sub_lines(boxes, words, paragraph_break))
    return lines
145 changes: 82 additions & 63 deletions src/ocrtoolkit/utilities/det_utils.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,82 @@
import numpy as np
from typing import List, Tuple
from ocrtoolkit.utilities.geometry_utils import (
estimate_page_angle,
rotate_boxes,
)


def sort_boxes(boxes: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Compute a reading-order permutation for a set of boxes.

    Args:
        boxes: a 2-D array whose columns 0..3 are used as box coordinates
            (presumably x_min, y_min, x_max, y_max -- confirm against the
            callers), or a 3-D array of polygons, which is first de-rotated
            and collapsed to its per-polygon min/max extents.

    Returns:
        A tuple ``(order, boxes)``: ``order`` is the argsort of the sorting
        key, and ``boxes`` is the 2-D array the key was computed on (the
        input itself when it was already 2-D).
    """
    is_rotated = boxes.ndim == 3
    if is_rotated:
        page_angle = estimate_page_angle(boxes)
        straightened = rotate_boxes(
            loc_preds=boxes,
            angle=-page_angle,
            orig_shape=(1024, 1024),
            min_angle=5.0,
        )
        # Collapse every polygon to the min/max of its points.
        lower = straightened.min(axis=1)
        upper = straightened.max(axis=1)
        boxes = np.concatenate((lower, upper), axis=-1)

    # Key = column 0 plus twice column 3 expressed in median box heights.
    median_height = np.median(boxes[:, 3] - boxes[:, 1])
    sort_key = boxes[:, 0] + 2 * boxes[:, 3] / median_height
    return sort_key.argsort(), boxes


def resolve_sub_lines(
    boxes: np.ndarray, word_idcs: List[int], paragraph_break: float
) -> List[List[int]]:
    """Split the words of one line into sub-lines at wide horizontal gaps.

    Args:
        boxes: 2-D array of boxes; column 0 and column 2 are read as the
            start and end of each box along the line.
        word_idcs: row indices into ``boxes`` for the words of the line.
        paragraph_break: gap threshold; a word whose column-0 value minus
            the previous word's column-2 value is not below this threshold
            starts a new sub-line.

    Returns:
        A list of sub-lines, each a list of indices ordered by column 0.
        When fewer than two indices are given, a single sub-line holding
        them (possibly empty) is returned.
    """
    ordered = sorted(word_idcs, key=lambda idx: boxes[idx, 0])
    if len(ordered) < 2:
        return [ordered]

    groups = [[ordered[0]]]
    for current in ordered[1:]:
        previous = groups[-1][-1]
        gap = boxes[current, 0] - boxes[previous, 2]
        if gap < paragraph_break:
            groups[-1].append(current)
        else:
            groups.append([current])
    return groups


def resolve_lines(boxes: np.ndarray, paragraph_break: float) -> List[List[int]]:
    """Group boxes into text lines, then split each line into sub-lines.

    Boxes are visited in the order produced by ``sort_boxes``. A box joins
    the current line while its vertical centre (mean of columns 1 and 3)
    lies within half the median box height of the running mean centre of
    that line; otherwise the current line is closed and a new one starts.
    Every closed line is passed through ``resolve_sub_lines``.

    Args:
        boxes: 2-D array of boxes or 3-D array of polygons, as accepted by
            ``sort_boxes``.
        paragraph_break: horizontal gap threshold forwarded to
            ``resolve_sub_lines``.

    Returns:
        A list of lines, each a list of indices into ``boxes``. An empty
        input yields an empty list.
    """
    # Without this guard an empty input fails on ``idxs[0]`` below.
    if len(boxes) == 0:
        return []

    idxs, boxes = sort_boxes(boxes)
    y_med = np.median(boxes[:, 3] - boxes[:, 1])

    lines: List[List[int]] = []
    words = [idxs[0]]
    y_center_sum = boxes[idxs[0], [1, 3]].mean()
    for idx in idxs[1:]:
        y_center = boxes[idx, [1, 3]].mean()
        # Compare against the mean centre of the words gathered so far.
        y_dist = abs(y_center - y_center_sum / len(words))

        if y_dist < y_med / 2:
            words.append(idx)
            y_center_sum += y_center
        else:
            lines.extend(resolve_sub_lines(boxes, words, paragraph_break))
            words, y_center_sum = [idx], y_center

    # ``words`` always holds at least one index here: flush the last line.
    lines.extend(resolve_sub_lines(boxes, words, paragraph_break))
    return lines
import json
from pathlib import Path

import h5py
from loguru import logger


def save_dets(l_dets, path: str):
    """Write a sequence of detection results to an HDF5 file.

    Each item of ``l_dets`` becomes a dataset ``dets_<position>`` inside a
    ``dets`` group. The dataset holds ``item.to_numpy(encode=True)`` and
    carries the item's ``width``, ``height`` and ``img_name`` as attributes.

    Args:
        l_dets: iterable of detection objects exposing ``to_numpy``,
            ``width``, ``height`` and ``img_name``.
        path: destination file, opened in ``"w"`` mode.
    """
    with h5py.File(path, "w") as h5_file:
        dets_group = h5_file.create_group("dets")
        for position, page_dets in enumerate(l_dets):
            encoded = page_dets.to_numpy(encode=True)
            dataset = dets_group.create_dataset(f"dets_{position}", data=encoded)
            # Metadata stored next to the boxes of each item.
            for attr_name in ("width", "height", "img_name"):
                dataset.attrs[attr_name] = getattr(page_dets, attr_name)
    logger.info(f"Detections saved to {path}")


def save_dets_as_label_studio(l_dets, path: str, subdir_images="images"):
    """Dump detections to a JSON file laid out as Label Studio tasks.

    One task is written per detection. Its image URL is
    ``/data/local-files/?d=<subdir_images>/<img_name>`` and it carries a
    single prediction whose results are the boxes of
    ``detection.normalize()``, with x/y/width/height multiplied by 100.

    Args:
        l_dets: iterable of detection objects exposing ``img_name``,
            ``width``, ``height`` and ``normalize()``.
        path: destination JSON file.
        subdir_images: directory joined in front of each image name when
            building the image URL.
    """
    url_template = "/data/local-files/?d={subdir_images}"
    images_root = Path(subdir_images)

    tasks = []
    for detection in l_dets:
        image_url = url_template.format(
            subdir_images=images_root.joinpath(detection.img_name).as_posix()
        )

        regions = []
        for number, bbox in enumerate(detection.normalize().bboxes, start=1):
            value = {
                "rotation": 0,
                "x": bbox.x1 * 100,
                "y": bbox.y1 * 100,
                "width": bbox.w * 100,
                "height": bbox.h * 100,
                "rectanglelabels": [bbox.label],
            }
            regions.append(
                {
                    "id": f"bbox{number}",
                    "type": "rectanglelabels",
                    "from_name": "label",
                    "to_name": "image",
                    "original_width": detection.width,
                    "original_height": detection.height,
                    "image_rotation": 0,
                    "value": value,
                }
            )

        prediction = {"model_version": "one", "score": 0.5, "result": regions}
        tasks.append({"data": {"image": image_url}, "predictions": [prediction]})

    with open(path, "w") as out_file:
        json.dump(tasks, out_file, indent=2)


def load_dets(path: str):
    """Read detection results back from an HDF5 file.

    Datasets of the ``dets`` group are visited in ascending order of the
    integer after the last underscore in their name. Each one is turned
    into a ``DetectionResults`` built from its rows (via
    ``BBox.from_numpy``) and its ``width``, ``height`` and ``img_name``
    attributes.

    Args:
        path: HDF5 file to read.

    Returns:
        A list of ``DetectionResults``, one per dataset.
    """
    # Imported at call time; presumably this avoids a circular import
    # between the utilities and wrappers packages -- confirm.
    from ocrtoolkit.wrappers.bbox import BBox
    from ocrtoolkit.wrappers.detection_results import DetectionResults

    def _dataset_index(name):
        return int(name.split("_")[-1])

    results = []
    with h5py.File(path, "r") as h5_file:
        dets_group = h5_file["dets"]
        for name in sorted(dets_group.keys(), key=_dataset_index):
            dataset = dets_group[name]
            width = int(dataset.attrs["width"])
            height = int(dataset.attrs["height"])
            img_name = str(dataset.attrs["img_name"])
            rows = dataset[()]
            bboxes = [BBox.from_numpy(row) for row in rows]
            results.append(DetectionResults(bboxes, width, height, img_name))
    return results
12 changes: 5 additions & 7 deletions src/ocrtoolkit/utilities/eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,8 @@ def compare_dataframes(
Returns:
pd.DataFrame: DataFrame containing comparison results as percentage of matches.
pd.DataFrame: The merged dataframes used for comparison.
"""
# Check if indices are named and assign names if not
if df_a.index.name is None:
df_a.index.name = index_a
if df_b.index.name is None:
df_b.index.name = index_b

# Set indices if they are not already set
if index_a != df_a.index.name:
df_a = df_a.set_index(index_a)
Expand All @@ -56,4 +51,7 @@ def compare_dataframes(
matches = (comparison_results[col_a] == comparison_results[col_b]).mean() * 100
match_percentages[col_a[:-2]] = f"{matches:.2f}%"

return pd.DataFrame(match_percentages, index=["Match Percentage"])
return (
pd.DataFrame(match_percentages, index=["Match Percentage"]),
comparison_results,
)
3 changes: 2 additions & 1 deletion src/ocrtoolkit/utilities/geometry_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
from typing import Optional, Tuple

import numpy as np


def estimate_page_angle(polys: np.ndarray) -> float:
"""Takes a batch of rotated previously
Expand Down
2 changes: 1 addition & 1 deletion src/ocrtoolkit/utilities/network_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Optional

from loguru import logger
from tqdm.auto import tqdm
from tqdm.autonotebook import tqdm

from ocrtoolkit.utilities.io_utils import extract_files

Expand Down
2 changes: 1 addition & 1 deletion src/ocrtoolkit/wrappers/detection_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

from ocrtoolkit.datasets.base import BaseDS
from ocrtoolkit.datasets.imageds import ImageDS
from ocrtoolkit.utilities.box_utils import resolve_lines
from ocrtoolkit.utilities.draw_utils import draw_bbox
from ocrtoolkit.utilities.misc_utils import get_samples, get_uuid
from ocrtoolkit.utilities.det_utils import resolve_lines
from ocrtoolkit.wrappers.bbox import BBox


Expand Down
82 changes: 0 additions & 82 deletions src/ocrtoolkit/wrappers/io.py

This file was deleted.

0 comments on commit c1d8f26

Please sign in to comment.