-
-
Notifications
You must be signed in to change notification settings - Fork 95
/
Copy pathrun_generate.py
66 lines (50 loc) · 1.99 KB
/
run_generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import traceback
from pathlib import Path
import cv2
import fire
import pandas as pd
from tqdm.contrib.concurrent import thread_map
from manga_ocr_dev.env import FONTS_ROOT, DATA_SYNTHETIC_ROOT
from manga_ocr_dev.synthetic_data_generator.generator import SyntheticDataGenerator
# Single module-level generator instance; every call to f() (which run()
# dispatches through thread_map) goes through this same object.
generator = SyntheticDataGenerator()
def f(args):
    """
    Generate one synthetic sample and write its image into OUT_DIR.

    OUT_DIR is a module-level global that run() sets before dispatching work.

    :param args: tuple ``(i, source, id_, text)`` as built by run(); ``i`` is not used here
    :return: tuple ``(source, id_, text_gt, vertical, font_path)`` for the metadata table,
        or None when the sample could not be generated or saved (the traceback is printed)
    """
    try:
        _, source, id_, text = args
        filename = f"{id_}.jpg"
        img, text_gt, params = generator.process(text)

        # cv2.imwrite can report failure through its return value (False) without
        # raising, so check it; otherwise a metadata row could point at an image
        # that was never written.
        out_path = OUT_DIR / filename
        if not cv2.imwrite(str(out_path), img):
            raise OSError(f"Failed to write image: {out_path}")

        # Store the font path relative to FONTS_ROOT.
        font_path = Path(params["font_path"]).relative_to(FONTS_ROOT)
        return source, id_, text_gt, params["vertical"], str(font_path)
    except Exception:
        # Best-effort: log the failure and skip this sample so that one bad
        # input does not abort the whole batch.
        print(traceback.format_exc())
        return None
def run(package=0, n_random=1000, n_limit=None, max_workers=16):
    """
    Generate one package of synthetic images together with its metadata CSV.

    Reads ``lines/<package>.csv`` under DATA_SYNTHETIC_ROOT, appends ``n_random``
    rows whose ``line`` is None, processes every row with ``f`` through
    ``thread_map``, writes the images to ``img/<package>/`` and the metadata of
    the successfully generated samples to ``meta/<package>.csv``.

    :param package: number of data package to generate
    :param n_random: how many samples with random text to generate
    :param n_limit: limit number of generated samples (for debugging)
    :param max_workers: max number of workers
    """
    # f() reads its output directory from this module-level global.
    global OUT_DIR

    package = f"{package:04d}"
    lines = pd.read_csv(DATA_SYNTHETIC_ROOT / f"lines/{package}.csv")
    random_lines = pd.DataFrame(
        {
            "source": "random",
            "id": [f"random_{package}_{i}" for i in range(n_random)],
            "line": None,
        }
    )
    lines = pd.concat([lines, random_lines], ignore_index=True)
    if n_limit:
        # Clamp to the available rows: DataFrame.sample raises when asked for
        # more rows than exist (without replacement).
        lines = lines.sample(min(n_limit, len(lines)))
    args = [(i, *values) for i, values in enumerate(lines.values)]

    OUT_DIR = DATA_SYNTHETIC_ROOT / "img" / package
    OUT_DIR.mkdir(parents=True, exist_ok=True)

    results = thread_map(f, args, max_workers=max_workers, desc=f"Processing package {package}")

    # f() returns None for samples that failed; drop those so they neither
    # become empty metadata rows nor break the DataFrame construction below.
    data = [row for row in results if row is not None]
    n_failed = len(results) - len(data)
    if n_failed:
        print(f"{n_failed} of {len(results)} samples failed and were skipped")

    data = pd.DataFrame(data, columns=["source", "id", "text", "vertical", "font_path"])
    meta_path = DATA_SYNTHETIC_ROOT / f"meta/{package}.csv"
    meta_path.parent.mkdir(parents=True, exist_ok=True)
    data.to_csv(meta_path, index=False)
# Script entry point: hand `run` to python-fire, which builds the
# command-line interface from it.
if __name__ == "__main__":
    fire.Fire(run)