-
Notifications
You must be signed in to change notification settings - Fork 2
/
script.py
150 lines (117 loc) · 3.63 KB
/
script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import glob
import os
from multiprocessing import Pool
from mltools import __CPUS__
from mltools.src.log.logger import logger
from tqdm import tqdm
from random import sample
import copy
# Module-level accumulators shared by the functions below:
#   RES     - split() appends one FileCls per analysed label file.
#   valList - get_train_val_set() extends it with the entries it selects.
RES, valList = [], []
class FileCls:
    """Per-class counts attached to one file name.

    ``filename`` identifies the record and ``clses`` maps a class id to a
    count. Equality and hashing look at ``filename`` only, so two records
    for the same file collapse to one element inside a ``set``.
    """

    def __init__(self, filename: str, clses: dict) -> None:
        self.filename = filename
        self.clses = clses

    def __repr__(self) -> str:
        return "{}(filename={!r}, clses={!r})".format(
            type(self).__name__, self.filename, self.clses
        )

    def get(self, k):
        """Return the count stored for class id *k*, or 0 when it is absent."""
        return self.clses.get(k, 0)

    def __add__(self, o):
        """Return a new record holding the per-class sum of both operands.

        The result has an empty ``filename`` and a fresh ``clses`` dict;
        neither operand is modified. Every class id present in either
        operand appears in the result, so ids do not have to be the
        contiguous range ``0..n-1``.

        When *o* is not an instance of this class, ``self`` is returned
        unchanged (kept for backward compatibility).
        """
        if not isinstance(o, self.__class__):
            return self
        # Start from a copy of our own counts so key order is preserved,
        # then fold in the other operand, creating missing ids on the fly.
        totals = dict(self.clses)
        for cls_id, count in o.clses.items():
            totals[cls_id] = totals.get(cls_id, 0) + count
        return FileCls("", totals)

    def __eq__(self, o: object) -> bool:
        """Two records are equal when they refer to the same file name."""
        if not isinstance(o, self.__class__):
            return False
        return self.filename == o.filename

    def __hash__(self) -> int:
        # Must agree with __eq__, which compares file names only.
        return hash(self.filename)
def get_info(filename: str, dic: dict):
    """Count, per class id, the lines of *filename* that start with that id.

    Each line is split on whitespace and its first field is parsed as an
    integer class id. Blank lines, lines whose first field is not an
    integer, and ids that are not already keys of *dic* are skipped.

    Args:
        filename: Path of the text file to read.
        dic: Counter dict pre-populated with every class id of interest.
            It is updated in place and also stored in the returned object.

    Returns:
        A ``FileCls`` wrapping *filename* and the updated *dic*.
    """
    with open(filename, "r") as f:
        # Stream the file instead of materialising every line up front.
        for line in f:
            fields = line.split()  # any whitespace, tolerates tabs/indent
            if not fields:
                continue
            try:
                clsnum = int(fields[0])
            except ValueError:
                # Not a class-id line; ignore it rather than abort the file.
                continue
            if clsnum in dic:
                dic[clsnum] += 1
    return FileCls(filename, dic)
def get_sum(res: list):
    """Fold the items of *res* together with ``+``, left to right.

    Returns ``None`` for an empty list, the sole item itself (not a copy)
    for a one-element list, and ``((res[0] + res[1]) + res[2]) + ...``
    otherwise.
    """
    count = len(res)
    if count == 0:
        return None
    total = res[0]
    # With a single item the range is empty and the item is returned as-is.
    for idx in range(1, count):
        total = total + res[idx]
    return total
def get_mean(s: FileCls, nc: int, times: int = 10):
    """Turn per-class totals into a per-class quota list.

    For every class id in ``range(nc)`` the total is read from ``s.clses``
    (missing ids count as 0) and mapped to a quota:

    * total <= 1      -> 0
    * total <= times  -> 1
    * otherwise       -> ``round(1 / times * total)``

    Returns:
        A list of ``[class_id, quota]`` pairs ordered by class id.
    """
    totals = s.clses

    def quota_for(total):
        if total <= 1:
            return 0
        if total <= times:
            return 1
        return round(1 / times * total)

    pairs = [[cls_id, quota_for(totals.get(cls_id, 0))] for cls_id in range(nc)]
    return sorted(pairs, key=lambda pair: pair[0])
def get_train_val_set(resList: list, r: list):
    """Select entries for the global ``valList`` until class quotas are met.

    *r* is walked in order. For each ``[class_id, quota]`` entry whose quota
    is still positive, up to ``quota`` entries containing that class are
    drawn at random from the not-yet-selected pool. The per-class counts of
    the drawn entries are then subtracted from every quota in *r*, so
    classes that came along with the pick are not over-sampled later.

    Args:
        resList: Entries exposing a ``clses`` dict (class id -> count);
            they must be hashable and addable, as ``FileCls`` is. The list
            itself is not modified; selection works on a deep copy.
        r: List of ``[class_id, quota]`` pairs. Quotas are decremented in
            place and may end up negative.

    Side effects:
        Extends the module-level ``valList`` with the selected entries.
    """
    global valList
    restRES = copy.deepcopy(resList)
    for entry in r:
        # Read the quota now: earlier picks may already have reduced it.
        class_id, wanted = entry
        if wanted <= 0:
            continue
        candidates = [item for item in restRES if item.clses.get(class_id, 0) > 0]
        # The quota can exceed the number of candidates (e.g. one entry
        # with many instances); random.sample would raise ValueError, so
        # clamp to what is actually available.
        picked = sample(candidates, k=min(wanted, len(candidates)))
        if not picked:
            continue
        picked_counts = get_sum(picked)
        for quota in r:
            quota[1] -= picked_counts.clses.get(quota[0], 0)
        picked_set = set(picked)
        restRES = [item for item in restRES if item not in picked_set]
        valList.extend(picked)
def split(folder: str, savaFolder: str, nc: int = 42, multiprocesses=True):
    """Split the ``*.txt`` files of *folder* into train and validation lists.

    Every file is analysed with ``get_info``, per-class quotas are derived
    with ``get_mean`` and validation files are chosen by
    ``get_train_val_set``. The selected file names are written one per line
    to ``val.txt`` and all remaining ones to ``train.txt`` inside
    *savaFolder*, each list sorted by name.

    Args:
        folder: Directory searched (non-recursively) for ``*.txt`` files.
        savaFolder: Existing directory that receives ``train.txt`` and
            ``val.txt``.
        nc: Number of class ids; ids ``0..nc-1`` are counted.
        multiprocesses: Analyse files in a process pool when true,
            sequentially in this process otherwise.

    Side effects:
        Resets and refills the module-level ``RES`` and ``valList``. Logs
        an error and returns without writing anything when *folder* holds
        no ``*.txt`` file.
    """
    global RES
    global valList
    txts = glob.glob(os.path.join(folder, "*.txt"))
    if not txts:
        logger.error("Folder is empty, none txt file found!")
        return

    # Start from a clean slate so a second call does not mix in the
    # results of an earlier one.
    RES.clear()
    valList.clear()

    logger.info("======== start analysing ========")
    if multiprocesses:
        # Pool() rejects a worker count below 1 (single-CPU hosts).
        workers = max(1, __CPUS__ - 1)
        with Pool(workers) as pool:
            pending = [
                pool.apply_async(get_info, (t, {k: 0 for k in range(nc)}))
                for t in txts
            ]
            # Collect inside the ``with``: leaving it terminates the pool.
            for job in tqdm(pending):
                RES.append(job.get())
    else:
        for t in tqdm(txts):
            RES.append(get_info(t, {k: 0 for k in range(nc)}))

    r = get_mean(get_sum(RES), nc)
    get_train_val_set(RES, r)

    val_names = {item.filename for item in valList}
    train_names = {item.filename for item in RES} - val_names

    # Open the outputs only once the split is known, so a failure above
    # does not leave truncated files behind.
    with open(os.path.join(savaFolder, "val.txt"), "w", encoding="utf-8") as trainvalFile:
        trainvalFile.writelines(name + "\n" for name in sorted(val_names))
    with open(os.path.join(savaFolder, "train.txt"), "w", encoding="utf-8") as trainFile:
        trainFile.writelines(name + "\n" for name in sorted(train_names))