7
7
import string
8
8
from collections import Counter
9
9
10
- from .utils import load_file , write_file , _parse_into_words
10
+ from .utils import load_file , write_file , _parse_into_words , ENSURE_UNICODE
11
11
12
12
13
13
class SpellChecker (object ):
@@ -62,10 +62,12 @@ def __init__(
62
62
63
63
def __contains__ (self , key ):
64
64
""" setup easier known checks """
65
+ key = ENSURE_UNICODE (key )
65
66
return key in self ._word_frequency
66
67
67
68
def __getitem__ (self , key ):
68
69
""" setup easier frequency checks """
70
+ key = ENSURE_UNICODE (key )
69
71
return self ._word_frequency [key ]
70
72
71
73
@property
@@ -105,6 +107,7 @@ def split_words(self, text):
105
107
text (str): The text to split into individual words
106
108
Returns:
107
109
list(str): A listing of all words in the provided text """
110
+ text = ENSURE_UNICODE (text )
108
111
return self ._tokenizer (text )
109
112
110
113
def export (self , filepath , encoding = "utf-8" , gzipped = True ):
@@ -131,6 +134,7 @@ def word_probability(self, word, total_words=None):
131
134
float: The probability that the word is the correct word """
132
135
if total_words is None :
133
136
total_words = self ._word_frequency .total_words
137
+ word = ENSURE_UNICODE (word )
134
138
return self ._word_frequency .dictionary [word ] / total_words
135
139
136
140
def correction (self , word ):
@@ -140,6 +144,7 @@ def correction(self, word):
140
144
word (str): The word to correct
141
145
Returns:
142
146
str: The most likely candidate """
147
+ word = ENSURE_UNICODE (word )
143
148
candidates = list (self .candidates (word ))
144
149
return max (sorted (candidates ), key = self .word_probability )
145
150
@@ -151,6 +156,7 @@ def candidates(self, word):
151
156
word (str): The word for which to calculate candidate spellings
152
157
Returns:
153
158
set: The set of words that are possible candidates """
159
+ word = ENSURE_UNICODE (word )
154
160
if self .known ([word ]): # short-cut if word is correct already
155
161
return {word }
156
162
# get edit distance 1...
@@ -174,6 +180,7 @@ def known(self, words):
174
180
Returns:
175
181
set: The set of those words from the input that are in the \
176
182
corpus """
183
+ words = [ENSURE_UNICODE (w ) for w in words ]
177
184
tmp = [w if self ._case_sensitive else w .lower () for w in words ]
178
185
return set (
179
186
w
@@ -191,6 +198,7 @@ def unknown(self, words):
191
198
Returns:
192
199
set: The set of those words from the input that are not in \
193
200
the corpus """
201
+ words = [ENSURE_UNICODE (w ) for w in words ]
194
202
tmp = [
195
203
w if self ._case_sensitive else w .lower ()
196
204
for w in words
@@ -207,7 +215,7 @@ def edit_distance_1(self, word):
207
215
Returns:
208
216
set: The set of strings that are edit distance one from the \
209
217
provided word """
210
- word = word .lower ()
218
+ word = ENSURE_UNICODE ( word ) .lower ()
211
219
if self ._check_if_should_check (word ) is False :
212
220
return {word }
213
221
letters = self ._word_frequency .letters
@@ -227,7 +235,7 @@ def edit_distance_2(self, word):
227
235
Returns:
228
236
set: The set of strings that are edit distance two from the \
229
237
provided word """
230
- word = word .lower ()
238
+ word = ENSURE_UNICODE ( word ) .lower ()
231
239
return [
232
240
e2 for e1 in self .edit_distance_1 (word ) for e2 in self .edit_distance_1 (e1 )
233
241
]
@@ -241,8 +249,13 @@ def __edit_distance_alt(self, words):
241
249
Returns:
242
250
set: The set of strings that are edit distance two from the \
243
251
provided words """
244
- words = [word .lower () for word in words ]
245
- return [e2 for e1 in words for e2 in self .edit_distance_1 (e1 )]
252
+ words = [ENSURE_UNICODE (w ) for w in words ]
253
+ tmp = [
254
+ w if self ._case_sensitive else w .lower ()
255
+ for w in words
256
+ if self ._check_if_should_check (w )
257
+ ]
258
+ return [e2 for e1 in tmp for e2 in self .edit_distance_1 (e1 )]
246
259
247
260
@staticmethod
248
261
def _check_if_should_check (word ):
@@ -283,11 +296,13 @@ def __init__(self, tokenizer=None, case_sensitive=False):
283
296
284
297
def __contains__ (self , key ):
285
298
""" turn on contains """
299
+ key = ENSURE_UNICODE (key )
286
300
key = key if self ._case_sensitive else key .lower ()
287
301
return key in self ._dictionary
288
302
289
303
def __getitem__ (self , key ):
290
304
""" turn on getitem """
305
+ key = ENSURE_UNICODE (key )
291
306
key = key if self ._case_sensitive else key .lower ()
292
307
return self ._dictionary [key ]
293
308
@@ -298,6 +313,7 @@ def pop(self, key, default=None):
298
313
Args:
299
314
key (str): The key to remove
300
315
default (obj): The value to return if key is not present """
316
+ key = ENSURE_UNICODE (key )
301
317
key = key if self ._case_sensitive else key .lower ()
302
318
return self ._dictionary .pop (key , default )
303
319
@@ -344,6 +360,7 @@ def tokenize(self, text):
344
360
str: The next `word` in the tokenized string
345
361
Note:
346
362
This is the same as the `spellchecker.split_words()` """
363
+ text = ENSURE_UNICODE (text )
347
364
for word in self ._tokenizer (text ):
348
365
yield word if self ._case_sensitive else word .lower ()
349
366
@@ -408,6 +425,7 @@ def load_text(self, text, tokenizer=None):
408
425
text (str): The text to be loaded
409
426
tokenizer (function): The function to use to tokenize a string
410
427
"""
428
+ text = ENSURE_UNICODE (text )
411
429
if tokenizer :
412
430
words = [x if self ._case_sensitive else x .lower () for x in tokenizer (text )]
413
431
else :
@@ -421,6 +439,7 @@ def load_words(self, words):
421
439
422
440
Args:
423
441
words (list): The list of words to be loaded """
442
+ words = [ENSURE_UNICODE (w ) for w in words ]
424
443
self ._dictionary .update (
425
444
[word if self ._case_sensitive else word .lower () for word in words ]
426
445
)
@@ -431,13 +450,15 @@ def add(self, word):
431
450
432
451
Args:
433
452
word (str): The word to add """
453
+ word = ENSURE_UNICODE (word )
434
454
self .load_words ([word ])
435
455
436
456
def remove_words (self , words ):
437
457
""" Remove a list of words from the word frequency list
438
458
439
459
Args:
440
460
words (list): The list of words to remove """
461
+ words = [ENSURE_UNICODE (w ) for w in words ]
441
462
for word in words :
442
463
self ._dictionary .pop (word if self ._case_sensitive else word .lower ())
443
464
self ._update_dictionary ()
@@ -447,6 +468,7 @@ def remove(self, word):
447
468
448
469
Args:
449
470
word (str): The word to remove """
471
+ word = ENSURE_UNICODE (word )
450
472
self ._dictionary .pop (word if self ._case_sensitive else word .lower ())
451
473
self ._update_dictionary ()
452
474
0 commit comments