Skip to content

Commit

Permalink
Merge branch 'master' of github.com:ajkdrag/ocrtoolkit
Browse files Browse the repository at this point in the history
  • Loading branch information
ajkdrag committed Apr 10, 2024
2 parents 53e8356 + e00eaae commit 8b51d0c
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 22 deletions.
126 changes: 104 additions & 22 deletions notebooks/saving_and_loading.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 5,
"id": "f9a8c75d-11d7-4f4d-923a-054f10a60105",
"metadata": {},
"outputs": [],
Expand All @@ -13,10 +13,19 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 9,
"id": "52306290-e9e9-4e96-9fe6-84a88e84b88d",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 3 -p\n",
Expand All @@ -42,7 +51,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 10,
"id": "46ddbbf0-abfc-409f-b9c4-573232839472",
"metadata": {},
"outputs": [],
Expand All @@ -52,27 +61,18 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 15,
"id": "008c7597-2e64-4ac9-b11a-938d33c3102c",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ajkdrag/workspace/ocrtoolkit/src/ocrtoolkit/utilities/io_utils.py:7: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
" from tqdm.autonotebook import tqdm\n"
]
}
],
"outputs": [],
"source": [
"from ocrtoolkit.core import detect\n",
"from ocrtoolkit.core import detect, detect_and_save_h5\n",
"from ocrtoolkit.datasets import FileDS"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 12,
"id": "0a561a4e-8b57-4d7d-822f-1baff89e5208",
"metadata": {},
"outputs": [],
Expand All @@ -90,15 +90,15 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 13,
"id": "3f4c3a41-a659-409c-9c5b-199469e9030a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2024-03-18 11:35:12.315\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mocrtoolkit.wrappers.model\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m22\u001b[0m - \u001b[1mLoaded model from ../pretrained/best.pt, to cpu\u001b[0m\n"
"\u001b[32m2024-03-30 20:51:41.466\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mocrtoolkit.wrappers.model\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m22\u001b[0m - \u001b[1mLoaded model from ../pretrained/best.pt, to cpu\u001b[0m\n"
]
}
],
Expand All @@ -116,14 +116,14 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 14,
"id": "9c1677f0-c574-43dc-af60-37a1e7684598",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3c1b9667baa6472e9417093e9a2c1132",
"model_id": "383ccd8a0e964e1a8620d15b9a57235b",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -138,14 +138,46 @@
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2024-03-18 11:35:20.645\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mocrtoolkit.utilities.io_utils\u001b[0m:\u001b[36mget_files\u001b[0m:\u001b[36m57\u001b[0m - \u001b[1mFound 235 files.\u001b[0m\n"
"\u001b[32m2024-03-30 20:51:41.729\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mocrtoolkit.utilities.io_utils\u001b[0m:\u001b[36mget_files\u001b[0m:\u001b[36m57\u001b[0m - \u001b[1mFound 235 files.\u001b[0m\n"
]
}
],
"source": [
"mini_ds = FileDS(\"../data/public/images/\").sample()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "9f41c860-517a-40f0-b76c-5a720661b56b",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8073008f3e33428488bf711937963187",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/237 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2024-03-30 20:53:03.613\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mocrtoolkit.utilities.io_utils\u001b[0m:\u001b[36mget_files\u001b[0m:\u001b[36m57\u001b[0m - \u001b[1mFound 235 files.\u001b[0m\n"
]
}
],
"source": [
"medium_ds = FileDS(\"../data/public/images/\").sample(k=50)"
]
},
{
"cell_type": "markdown",
"id": "47d1abff-d792-4104-b374-d4cbf71e4a62",
Expand All @@ -164,6 +196,56 @@
"det_results = detect(detection_model, mini_ds, stream=False, verbose=False)"
]
},
{
"cell_type": "markdown",
"id": "a44b8e59-2c00-4c22-8e9a-7083c16291b0",
"metadata": {},
"source": [
"### saving detection results during inference"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "86c13d81-5430-4777-96d8-f90266745bd0",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "34f06fd9166b42f883c8cfaf7f79fd93",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/10 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2024-03-30 20:53:59.691\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mocrtoolkit.utilities.det_utils\u001b[0m:\u001b[36msave_dets\u001b[0m:\u001b[36m17\u001b[0m - \u001b[1mDetections saved to temp/medium_dets_save.h5\u001b[0m\n"
]
}
],
"source": [
"detect_and_save_h5(detection_model, medium_ds, \"temp/medium_dets_save.h5\", bs=5, verbose=False)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "d3afe405-7928-40c6-a038-7df75fb22fc9",
"metadata": {},
"outputs": [],
"source": [
"loaded_medium_dets = load_dets(\"temp/medium_dets_save.h5\")"
]
},
{
"cell_type": "markdown",
"id": "2716e2d9-defa-4ab1-9bbd-722112be6b00",
Expand Down
25 changes: 25 additions & 0 deletions src/ocrtoolkit/core/detector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import numpy as np
from loguru import logger
from tqdm.auto import tqdm

from ocrtoolkit.datasets.base import BaseDS
from ocrtoolkit.utilities.det_utils import save_dets
from ocrtoolkit.wrappers.model import DetectionModel


Expand Down Expand Up @@ -34,3 +36,26 @@ def detect(model: DetectionModel, ds: BaseDS, stream=True, **kwargs):
if stream:
return gen
return list(_detect(model, ds, **kwargs))


def detect_and_save_h5(
    model: DetectionModel,
    ds: BaseDS,
    path: str,
    bs=4,
    start_batch_idx=0,
    **kwargs,
):
    """Run detection over a dataset batch by batch and save the results.

    The dataset is split into ``ds.num_batches(bs)`` batches. Every batch
    from ``start_batch_idx`` onwards is fetched with ``ds.batch(bs, idx)``
    and passed to ``detect`` in non-streaming mode. The per-batch results
    are concatenated into one list, which is handed to ``save_dets``
    together with ``path`` once all batches are done.

    Args:
        model: Detection model forwarded to ``detect``.
        ds: Dataset to run detection on.
        path: Destination passed to ``save_dets`` (presumably an HDF5
            file, going by the function name -- confirm in ``save_dets``).
        bs: Batch size used to split ``ds``.
        start_batch_idx: Index of the first batch to process. Earlier
            batches are skipped and are therefore absent from the saved
            output.
        **kwargs: Extra keyword arguments forwarded to ``detect``. Do not
            pass ``stream`` here; it is already fixed to ``False``.

    Returns:
        None. The detections are only written via ``save_dets``.

    Caution: all detection results are held in memory until the final
    save, so nothing is written if a batch fails part-way through.

    NOTE(review): the previous docstring stated that model.preprocess is
    called before model.predict and that images should be np.ndarray
    before preprocessing. That is not visible in this function; confirm
    against ``detect``.
    """
    batch_indices = range(start_batch_idx, ds.num_batches(bs))

    # Flatten the per-batch result lists into a single list, in batch order.
    collected = [
        det_result
        for batch_idx in tqdm(batch_indices)
        for det_result in detect(
            model, ds.batch(bs, batch_idx), stream=False, **kwargs
        )
    ]

    save_dets(collected, path)

0 comments on commit 8b51d0c

Please sign in to comment.