decode.py · 185 lines (149 loc) · 4.4 KB
import collections
import logging
import os.path
import re
import sys
import time

import gflags
import nltk
import nltk.tokenize
from nltk.corpus import gutenberg

import viterbi

FLAGS = gflags.FLAGS
logger = logging.getLogger(__name__)

gflags.DEFINE_float(
    'error_prob',
    0.001,
    'The probability that a letter was read or written incorrectly.')

gflags.DEFINE_bool(
    'dump_tokens',
    False,
    'Dump a list of all tokens found in the corpora, then exit.')


class Error(Exception):
    pass


TOKEN_RE = re.compile(r'[^a-z\']')


def is_token(s):
    return not TOKEN_RE.search(s)
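
# Illustrative examples (editor's sketch, not part of the original source):
#   is_token("hello")   -> True
#   is_token("don't")   -> True
#   is_token("Hello!")  -> False   (uppercase letters and punctuation fall
#                                   outside [a-z'])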


def reposses(tokens):
    last_token = None
    for token in tokens:
        if not last_token:
            last_token = token
        elif token == "'s":
            yield last_token + token
            last_token = None
        else:
            yield last_token
            last_token = token
    if last_token:
        yield last_token
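
# Hedged example of the generator above (editor's illustration):
#   list(reposses(['the', 'dog', "'s", 'bone']))  ->  ['the', "dog's", 'bone']
# word_tokenize splits possessives into a separate "'s" token; reposses()
# reattaches them to the preceding word.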


class Pdist(dict):
    "A probability distribution estimated from counts."

    def __init__(self, data=[], N=None, missingfn=None):
        for key, count in data:
            self[key] = self.get(key, 0) + int(count)
        self.N = float(N or sum(self.itervalues()))
        self.missingfn = missingfn or (lambda k, N: 1. / N)

    def __call__(self, key):
        if key in self:
            return self[key] / self.N
        else:
            return self.missingfn(key, self.N)
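
# Hedged usage sketch for Pdist (editor's illustration; the counts are
# made up):
#   p = Pdist([(('the',), 3), (('cat',), 1)])
#   p(('the',))     -> 0.75   (3 of 4 observed counts)
#   p(('unseen',))  -> 0.25   (falls back to missingfn, 1/N by default)
# Decoder below stores unigram and bigram Counter items in Pdist instances,
# so calling an instance yields a smoothed probability while plain indexing
# returns the raw count.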


CORPUS_ROOT = os.path.join(os.path.dirname(__file__), 'corpora')


class Decoder(object):
    def __init__(self, corpora_ids, error_prob=None):
        self.error_prob = error_prob or FLAGS.error_prob
        extra_corpora = nltk.corpus.PlaintextCorpusReader(CORPUS_ROOT, '.*')
        words = []
        timer = Timer()
        for corpus_id in corpora_ids:
            logger.info('Loading corpus %s', corpus_id)
            corpus = extra_corpora.raw(corpus_id)
            for sentence in nltk.tokenize.sent_tokenize(corpus):
                sent_words = [
                    w.lower()
                    for w in reposses(nltk.tokenize.word_tokenize(sentence))]
                sent_words = [filter(is_token, w) for w in sent_words]
                sent_words = [w for w in sent_words if w]
                if sent_words:
                    # '$' marks the start of each sentence.
                    sent_words = ['$'] + sent_words
                    words += sent_words
        self.states = set(words)
        self.Pw = Pdist(collections.Counter(nltk.ngrams(words, 1)).items())
        self.P2w = Pdist(collections.Counter(nltk.ngrams(words, 2)).items())
        # Index the vocabulary by first letter so decode() only has to
        # consider words that could have produced each observed initial.
        self.words_by_letter = collections.defaultdict(set)
        for w in words:
            self.words_by_letter[w[0]].update([w])
        logger.info(
            'Loading %s corpora took %s s', len(corpora_ids), timer.elapsed())
        logger.info('%s distinct tokens', len(self.states))

    def start_p(self, word):
        prob = self.Pw((word,))
        return prob

    def transition_p(self, prev, word):
        try:
            a = self.P2w[(prev, word)]
            b = self.Pw[(prev,)]
            return a / float(b)
        except KeyError:
            return self.Pw((word,))
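
    # Editor's note: transition_p estimates P(word | prev) as the bigram
    # count divided by the unigram count of prev, backing off to the
    # unigram probability P(word) when either count is missing. With
    # hypothetical counts of 2 for ('the', 'cat') and 10 for ('the',),
    # transition_p('the', 'cat') would return 0.2.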

    def emission_p(self, word, letter):
        if word[0] == letter:
            return 1.0 - self.error_prob
        else:
            return self.error_prob / 27.0  # a-z and $
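
    # Editor's note: the emission model assumes the observed letter is the
    # word's true initial with probability (1 - error_prob); otherwise the
    # word scores error_prob / 27, the error probability spread over the 27
    # possible symbols (a-z plus the '$' sentence marker). With the default
    # error_prob of 0.001, a match scores 0.999 and a mismatch about 3.7e-05.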

    def decode(self, initials):
        timer = Timer()
        states = set()
        for obs in initials:
            states.update(self.words_by_letter[obs])
        logger.info('Searching %s possible states', len(states))
        result = viterbi.viterbi(
            initials,
            states,
            self.start_p,
            self.transition_p,
            self.emission_p)
        logger.info('Decoding %r took %s s', initials, timer.elapsed())
        return result
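

# Hedged usage sketch (editor's illustration; the corpus file name is
# hypothetical and would have to exist under ./corpora):
#   decoder = Decoder(['some_text.txt'])
#   prob, words = decoder.decode('wdytia')
# Assuming viterbi.viterbi returns a (probability, path) pair, as repl()
# below expects, 'words' is the most likely word sequence whose initials
# match the observed letters and 'prob' is its path probability.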


class Timer(object):
    """Keeps track of wall-clock time."""

    def __init__(self):
        self.start_time = None
        self.reset()

    def reset(self):
        """Resets the timer."""
        self.start_time = time.time()

    def elapsed(self):
        """Returns the elapsed time in seconds.

        Elapsed time is the time since the timer was created or last
        reset.
        """
        return time.time() - self.start_time


def repl(decoder):
    try:
        while True:
            if sys.stdin.isatty():
                line = raw_input('Enter initials:\n')
            else:
                line = raw_input()
            logger.info('Decoding %r', line)
            prob, words = decoder.decode(line.lower())
            print prob, ' '.join(words)
    except EOFError:
        pass


def dump_tokens(decoder):
    for word in sorted(decoder.states):
        print word


def main(argv):
    args = FLAGS(argv)[1:]
    logging.basicConfig(level=logging.INFO)
    if not args:
        sys.stderr.write('Must give corpora names\n')
        sys.exit(1)
    decoder = Decoder(args)
    if FLAGS.dump_tokens:
        dump_tokens(decoder)
    else:
        repl(decoder)


if __name__ == '__main__':
    main(sys.argv)
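
# Hedged invocation examples (editor's sketch; corpus file names are
# hypothetical and would have to exist under the ./corpora directory):
#   python decode.py some_text.txt
#   python decode.py --error_prob=0.01 corpus1.txt corpus2.txt
#   python decode.py --dump_tokens corpus1.txt
# Without --dump_tokens the script starts the REPL and reads one line of
# initials at a time from stdin.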