-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_tools.py
66 lines (43 loc) · 1.92 KB
/
data_tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# --- imports --- #
from random import sample, shuffle
from os import scandir, path, makedirs
from shutil import rmtree, copy
from glob import glob
from itertools import islice
from copy import deepcopy
import re
# --- functions --- #
def regex_replace(pattern, string, replacement='', function='extract'):
    """Extract the first capture group of ``pattern`` from ``string``, or substitute matches.

    Args:
        pattern: Regular expression (str or compiled pattern). In ``'extract'`` mode it must
            contain at least one capture group.
        string: Text to search.
        replacement: Replacement text, used only when ``function`` is not ``'extract'``.
        function: ``'extract'`` (default) returns group 1 of the first match; any other value
            returns ``re.sub(pattern, replacement, string)``.

    Returns:
        The captured text (``'extract'`` mode) or the substituted string.

    Raises:
        ValueError: In ``'extract'`` mode when ``pattern`` does not match ``string``.
    """
    if function != 'extract':
        return re.sub(pattern, replacement, string)

    match = re.search(pattern, string)
    if match is None:
        # Without this check a miss surfaced as an opaque AttributeError on ``None.group``.
        raise ValueError(f'pattern {pattern!r} did not match {string!r}')
    return match.group(1)
def _list_class_files(class_dir):
    """Return paths of the regular, non-hidden files directly inside ``class_dir``."""
    # Skipping dot-files mirrors the ``glob('*/*')`` listing this replaces.
    with scandir(class_dir) as entries:
        return [e.path for e in entries if e.is_file() and not e.name.startswith('.')]


def _copy_files(files, dest_dir):
    """Copy every path in ``files`` into ``dest_dir``, keeping each file's base name."""
    for src in files:
        # path.basename understands the platform's separators; the former backslash-only
        # regex failed on POSIX paths.
        copy(src, path.join(dest_dir, path.basename(src)))


def create_train_test_val(img_dir, split_dir, weights=None, dir_types=('train', 'test', 'val'),
                          n=None, total_files=10000000, rebuild=True, custom=False, custom_n=None,
                          exclude=('dining room',)):
    """Copy images from per-class folders into split folders (train/test/val or a custom sample).

    ``img_dir`` is expected to hold one sub-directory per class, each containing image files.
    Each class's files are shuffled. Consecutive, non-overlapping slices are then copied to
    ``split_dir/<dir_type>/<class>/``. In ``custom`` mode, a random sample is instead copied
    to ``split_dir/<class>/``.

    Args:
        img_dir: Directory whose sub-directories are the classes.
        split_dir: Output root directory.
        weights: Per-``dir_type`` fractions, each multiplied by ``min(total_files, class size)``
            and floored to get a file count. Used when ``n`` is falsy.
        dir_types: Names of the split sub-folders, in the order slices are taken.
        n: Explicit per-``dir_type`` file counts; takes precedence over ``weights`` when truthy.
        total_files: Per-class cap on the number of files that ``weights`` is applied to.
        rebuild: If True, delete ``split_dir`` before writing.
        custom: If True, ignore ``dir_types``/``weights``/``n`` and copy ``custom_n`` randomly
            sampled files per class into ``split_dir/<class>/``.
        custom_n: Sample size per class in ``custom`` mode, capped at the class size.
            ``None`` means ``min(total_files, class size)``.
        exclude: Class folder names to skip (hidden folders are always skipped).

    Raises:
        ValueError: If neither ``n`` nor ``weights`` is given outside ``custom`` mode, or
            either provides fewer entries than ``dir_types``. Raised before anything on
            disk is touched.
    """
    dir_types = tuple(dir_types)
    if not custom:
        # Validate before rmtree so a bad call cannot destroy an existing split.
        if n:
            n = list(n)
            given = n
        elif weights is not None:
            weights = list(weights)
            given = weights
        else:
            raise ValueError('create_train_test_val: pass `weights` or `n` (or set custom=True)')
        if len(given) < len(dir_types):
            raise ValueError(
                f'create_train_test_val: need {len(dir_types)} split sizes, got {len(given)}')

    if rebuild and path.exists(split_dir):
        rmtree(split_dir)

    with scandir(img_dir) as entries:
        classes = [e.name for e in entries
                   if e.is_dir() and e.name not in exclude and not e.name.startswith('.')]

    # Create every destination folder up front, including ones that end up empty.
    dest_roots = [split_dir] if custom else [path.join(split_dir, d) for d in dir_types]
    for root in dest_roots:
        for img_class in classes:
            makedirs(path.join(root, img_class), exist_ok=True)

    for img_class in classes:
        # List files by directory, not by substring match on the path: "room" must not
        # pick up files from "bedroom" or the excluded "dining room".
        images = _list_class_files(path.join(img_dir, img_class))
        shuffle(images)
        num_files = min(total_files, len(images))

        if custom:
            take = num_files if custom_n is None else min(custom_n, len(images))
            _copy_files(sample(images, take), path.join(split_dir, img_class))
            continue

        counts = n if n else [int(w * num_files) for w in weights]
        # Take consecutive, non-overlapping slices. The former islice(images, k) restarted
        # at index 0 for every split, so test/val were subsets of train.
        start = 0
        for dir_type, count in zip(dir_types, counts):
            _copy_files(images[start:start + count], path.join(split_dir, dir_type, img_class))
            start += count