-
Notifications
You must be signed in to change notification settings - Fork 5
/
predict_d3.py
120 lines (92 loc) · 4.54 KB
/
predict_d3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import argparse
import os
from collections import Counter
import torch as T
import yaml
from pyarabic.araby import tokenize
from tqdm import tqdm
from torch.utils.data import DataLoader
from model_d3 import DiacritizerD3
from data_utils import DatasetUtils
from dataloader import DataRetriever
# Device used for loading the checkpoint and running the model. Prefer the GPU,
# but fall back to the CPU so the script still runs on machines without CUDA.
DEVICE = 'cuda' if T.cuda.is_available() else 'cpu'
# Name of the data split handed to the DatasetUtils loaders and DataRetriever.
TEST_FILE = "test"
class Predictor:
    """Set up everything needed to run the D3 diacritizer on the test split.

    After construction the instance exposes:
        data_utils: the ``DatasetUtils`` built from ``config``.
        mapping: result of ``data_utils.load_mapping_v3(TEST_FILE)``.
        original_lines: result of ``data_utils.load_file_clean(TEST_FILE, strip=True)``.
        model: a ``DiacritizerD3`` with restored weights, moved to ``DEVICE``
            and switched to evaluation mode.
        data_loader: an un-shuffled ``DataLoader`` over the test split.
    """

    def __init__(self, config):
        """Build the model, restore its weights and prepare the test loader.

        Args:
            config: nested configuration mapping. This constructor reads
                ``config["paths"]["load"]`` (checkpoint path),
                ``config["predictor"]["batch-size"]`` and
                ``config["loader"]["num-workers"]``; the whole mapping is also
                passed on to ``DatasetUtils`` and ``DiacritizerD3``.
        """
        self.data_utils = DatasetUtils(config)

        vocab_size = len(self.data_utils.letter_list)
        word_embeddings = self.data_utils.embeddings

        self.mapping = self.data_utils.load_mapping_v3(TEST_FILE)
        self.original_lines = self.data_utils.load_file_clean(TEST_FILE, strip=True)

        self.model = DiacritizerD3(config, device=DEVICE)
        self.model.build(word_embeddings, vocab_size)

        # NOTE(review): T.load unpickles the checkpoint file, so only point
        # config["paths"]["load"] at checkpoints from a trusted source.
        checkpoint = T.load(config["paths"]["load"], map_location=T.device(DEVICE))
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.to(DEVICE)
        self.model.eval()

        # Use the same split name as the mapping/lines loaded above so the
        # three can never get out of sync.
        testset = DataRetriever(TEST_FILE, self.data_utils, is_test=True)

        self.data_loader = DataLoader(
            testset,
            # The configured batch size is capped at 128.
            batch_size=min(config["predictor"]["batch-size"], 128),
            # Order must be kept: predictions are later addressed by index.
            shuffle=False,
            num_workers=config["loader"]["num-workers"],
        )
class PredictTri(Predictor):
    """Predictor that turns three per-character outputs into diacritized text.

    The model yields a haraka class, a tanween flag and a shadda flag for each
    character position; a character covered by several positions gets the
    majority vote of each output.
    """

    def __init__(self, config):
        """Initialise the base predictor and the haraka class codes."""
        super().__init__(config)
        # Integer codes the predicted haraka class is compared against.
        self.diacritics = {
            "FATHA": 1,
            "KASRA": 2,
            "DAMMA": 3,
            "SUKUN": 4
        }

    def shakkel_char(self, diac: int, tanween: bool, shadda: bool) -> str:
        """Return the diacritic mark(s) to append after one character.

        Args:
            diac: haraka class code (see ``self.diacritics``); any other value
                contributes no haraka.
            tanween: truthy to emit the tanween form of fatha/kasra/damma.
            shadda: truthy to emit a shadda before the haraka.

        Returns:
            A string of zero, one or two combining marks.
        """
        returned_text = ""

        # A shadda is never combined with a sukun.
        if shadda and diac != self.diacritics["SUKUN"]:
            returned_text += "\u0651"

        if diac == self.diacritics["FATHA"]:
            returned_text += "\u064E" if not tanween else "\u064B"
        elif diac == self.diacritics["KASRA"]:
            returned_text += "\u0650" if not tanween else "\u064D"
        elif diac == self.diacritics["DAMMA"]:
            returned_text += "\u064F" if not tanween else "\u064C"
        elif diac == self.diacritics["SUKUN"]:
            returned_text += "\u0652"

        return returned_text

    def _vote_coordinates(self, sent_idx):
        """Invert one sentence's mapping to char_idx -> [(seg_idx, t_idx, c_idx), ...].

        Coordinates are listed in mapping iteration order (segment, then
        token), and only the first position of a character inside a token list
        is kept, mirroring ``list.index``. Sentences absent from the mapping
        yield an empty dict.
        """
        coordinates = {}
        if sent_idx not in self.mapping:
            return coordinates

        segments = self.mapping[sent_idx]
        for seg_idx in segments:
            tokens = segments[seg_idx]
            for t_idx in tokens:
                seen = set()
                for c_idx, char_idx in enumerate(tokens[t_idx]):
                    if char_idx in seen:
                        continue
                    seen.add(char_idx)
                    coordinates.setdefault(char_idx, []).append((seg_idx, t_idx, c_idx))
        return coordinates

    @staticmethod
    def _majority(predictions, voters):
        """Return the most frequent prediction among ``voters``; ties go to the first seen."""
        votes = Counter(predictions[seg][tok][pos] for seg, tok, pos in voters)
        return votes.most_common(1)[0][0]

    def predict_mv(self):
        """Diacritize every test line using majority voting.

        Runs ``self.model.predict`` on the test loader, which returns three
        structures (haraka, tanween, shadda) addressed as
        ``[seg_idx][t_idx][c_idx]``. ``self.mapping`` links those positions
        back to ``[sent_idx][char_idx]`` of the tokenized line.

        Returns:
            A list with one stripped, diacritized string per original line.
        """
        y_gen_diac, y_gen_tanween, y_gen_shadda = self.model.predict(self.data_loader)
        diacritized_lines = []

        for sent_idx, line in tqdm(enumerate(self.original_lines), total=len(self.original_lines)):
            # Character indices in the mapping refer to the re-joined tokens.
            line = ' '.join(tokenize(line))

            # mapping: [seg_idx][t_idx][c_idx] --> [sent_idx][char_idx]
            # Built once per sentence instead of rescanning the whole mapping
            # for every character.
            coordinates = self._vote_coordinates(sent_idx)

            pieces = []
            for char_idx, char in enumerate(line):
                pieces.append(char)

                voters = coordinates.get(char_idx)
                if not voters:
                    # No prediction covers this character: keep it bare.
                    continue

                char_mv_diac = self._majority(y_gen_diac, voters)
                char_mv_shadda = self._majority(y_gen_shadda, voters)
                char_mv_tanween = self._majority(y_gen_tanween, voters)
                pieces.append(self.shakkel_char(char_mv_diac, char_mv_tanween, char_mv_shadda))

            diacritized_lines.append(''.join(pieces).strip())

        return diacritized_lines
if __name__ == "__main__":
    # Command-line entry point: load the config, predict, write the result.
    parser = argparse.ArgumentParser(description='Paramaters')
    parser.add_argument('-c', '--config', type=str,
                        default="configs/config_d3.yaml", help='path of config file')
    args = parser.parse_args()

    with open(args.config, 'r', encoding="utf-8") as file:
        # NOTE(review): FullLoader accepts more than plain data tags; switch to
        # yaml.safe_load if config files may come from an untrusted source.
        config = yaml.load(file, Loader=yaml.FullLoader)

    # At prediction time the sentence length limit is the predictor window.
    config["train"]["max-sent-len"] = config["predictor"]["window"]

    predictor = PredictTri(config)
    diacritized_lines = predictor.predict_mv()

    # The experiment id is the last dash-separated part of the run title.
    exp_id = config["run-title"].split("-")[-1].lower()

    # Make sure the output directory exists so a finished prediction run is
    # not lost to a FileNotFoundError on a fresh checkout.
    preds_dir = os.path.join(config["paths"]["base"], 'preds')
    os.makedirs(preds_dir, exist_ok=True)

    with open(os.path.join(preds_dir, f'predictions_{exp_id}.txt'), 'w', encoding='utf-8') as fout:
        fout.write('\n'.join(diacritized_lines))