17
17
json_string , fitted_required , replace_entities_with_placeholders ,
18
18
check_persisted_path )
19
19
from snips_nlu .constants import (
20
- DATA , END , ENTITY , ENTITY_KIND , LANGUAGE , NGRAM , RES_MATCH_RANGE ,
21
- RES_VALUE , START , TEXT , ENTITIES )
20
+ DATA , ENTITY , ENTITY_KIND , LANGUAGE , NGRAM , TEXT , ENTITIES )
22
21
from snips_nlu .dataset import get_text_from_chunks , validate_and_format_dataset
23
22
from snips_nlu .entity_parser .builtin_entity_parser import (
24
23
is_builtin_entity )
@@ -264,7 +263,7 @@ def fit(self, x, dataset):
264
263
self ._init_vectorizer (self ._language )
265
264
self .builtin_entity_scope = set (
266
265
e for e in dataset [ENTITIES ] if is_builtin_entity (e ))
267
- preprocessed_data = self ._preprocess (x , training = True )
266
+ preprocessed_data = self ._preprocess (x )
268
267
utterances = [
269
268
self ._enrich_utterance (u , builtin_ents , custom_ents , w_clusters )
270
269
for u , builtin_ents , custom_ents , w_clusters
@@ -296,7 +295,7 @@ def fit_transform(self, x, dataset):
296
295
self ._init_vectorizer (self ._language )
297
296
self .builtin_entity_scope = set (
298
297
e for e in dataset [ENTITIES ] if is_builtin_entity (e ))
299
- preprocessed_data = self ._preprocess (x , training = True )
298
+ preprocessed_data = self ._preprocess (x )
300
299
utterances = [
301
300
self ._enrich_utterance (u , builtin_ents , custom_ents , w_clusters )
302
301
for u , builtin_ents , custom_ents , w_clusters
@@ -330,31 +329,30 @@ def transform(self, x):
330
329
for data in zip (* self ._preprocess (x ))]
331
330
return self ._tfidf_vectorizer .transform (utterances )
332
331
333
- def _preprocess (self , utterances , training = False ):
332
+ def _preprocess (self , utterances ):
334
333
normalized_utterances = deepcopy (utterances )
335
334
for u in normalized_utterances :
336
- for chunk in u [DATA ]:
335
+ nb_chunks = len (u [DATA ])
336
+ for i , chunk in enumerate (u [DATA ]):
337
337
chunk [TEXT ] = _normalize_stem (
338
338
chunk [TEXT ], self .language , self .resources ,
339
339
self .config .use_stemming )
340
-
341
- if training :
342
- builtin_ents , custom_ents = zip (
343
- * [_entities_from_utterance (u ) for u in utterances ])
344
- else :
345
- # Extract builtin entities on unormalized utterances
346
- builtin_ents = [
347
- self .builtin_entity_parser .parse (
348
- get_text_from_chunks (u [DATA ]),
349
- self .builtin_entity_scope , use_cache = True )
350
- for u in utterances
351
- ]
352
- # Extract builtin entities on normalized utterances
353
- custom_ents = [
354
- self .custom_entity_parser .parse (
355
- get_text_from_chunks (u [DATA ]), use_cache = True )
356
- for u in normalized_utterances
357
- ]
340
+ if i < nb_chunks - 1 :
341
+ chunk [TEXT ] += " "
342
+
343
+ # Extract builtin entities on unnormalized utterances
344
+ builtin_ents = [
345
+ self .builtin_entity_parser .parse (
346
+ get_text_from_chunks (u [DATA ]),
347
+ self .builtin_entity_scope , use_cache = True )
348
+ for u in utterances
349
+ ]
350
+ # Extract custom entities on normalized utterances
351
+ custom_ents = [
352
+ self .custom_entity_parser .parse (
353
+ get_text_from_chunks (u [DATA ]), use_cache = True )
354
+ for u in normalized_utterances
355
+ ]
358
356
if self .config .word_clusters_name :
359
357
# Extract word clusters on unnormalized utterances
360
358
original_utterances_text = [get_text_from_chunks (u [DATA ])
@@ -582,7 +580,7 @@ def fit(self, x, dataset):
582
580
self .builtin_entity_scope = set (
583
581
e for e in dataset [ENTITIES ] if is_builtin_entity (e ))
584
582
585
- preprocessed = self ._preprocess (list (x ), training = True )
583
+ preprocessed = self ._preprocess (list (x ))
586
584
utterances = [
587
585
self ._enrich_utterance (utterance , builtin_ents , custom_ent )
588
586
for utterance , builtin_ents , custom_ent in zip (* preprocessed )]
@@ -648,7 +646,7 @@ def transform(self, x):
648
646
Raises:
649
647
NotTrained: when the vectorizer is not fitted
650
648
"""
651
- preprocessed = self ._preprocess (x , training = False )
649
+ preprocessed = self ._preprocess (x )
652
650
utterances = [
653
651
self ._enrich_utterance (utterance , builtin_ents , custom_ent )
654
652
for utterance , builtin_ents , custom_ent in zip (* preprocessed )]
@@ -661,24 +659,20 @@ def transform(self, x):
661
659
662
660
return x_coo .tocsr ()
663
661
664
- def _preprocess (self , x , training = False ):
665
- if training :
666
- builtin_ents , custom_ents = zip (
667
- * [_entities_from_utterance (u ) for u in x ])
668
- else :
669
- # Extract all entities on unnormalized data
670
- builtin_ents = [
671
- self .builtin_entity_parser .parse (
672
- get_text_from_chunks (u [DATA ]),
673
- self .builtin_entity_scope ,
674
- use_cache = True
675
- ) for u in x
676
- ]
677
- custom_ents = [
678
- self .custom_entity_parser .parse (
679
- get_text_from_chunks (u [DATA ]), use_cache = True )
680
- for u in x
681
- ]
662
+ def _preprocess (self , x ):
663
+ # Extract all entities on unnormalized data
664
+ builtin_ents = [
665
+ self .builtin_entity_parser .parse (
666
+ get_text_from_chunks (u [DATA ]),
667
+ self .builtin_entity_scope ,
668
+ use_cache = True
669
+ ) for u in x
670
+ ]
671
+ custom_ents = [
672
+ self .custom_entity_parser .parse (
673
+ get_text_from_chunks (u [DATA ]), use_cache = True )
674
+ for u in x
675
+ ]
682
676
return x , builtin_ents , custom_ents
683
677
684
678
def _extract_word_pairs (self , utterance ):
@@ -805,27 +799,3 @@ def _get_word_cluster_features(query_tokens, clusters_name, resources):
805
799
if cluster is not None :
806
800
cluster_features .append (cluster )
807
801
return cluster_features
808
-
809
-
810
- def _entities_from_utterance (utterance ):
811
- builtin_ents = []
812
- custom_ents = []
813
- current_ix = 0
814
- for chunk in utterance [DATA ]:
815
- text = chunk [TEXT ]
816
- text_length = len (text )
817
- if ENTITY in chunk :
818
- ent = {
819
- ENTITY_KIND : chunk [ENTITY ],
820
- RES_VALUE : text ,
821
- RES_MATCH_RANGE : {
822
- START : current_ix ,
823
- END : current_ix + text_length
824
- }
825
- }
826
- if is_builtin_entity (ent [ENTITY_KIND ]):
827
- builtin_ents .append (ent )
828
- else :
829
- custom_ents .append (ent )
830
- current_ix += text_length
831
- return builtin_ents , custom_ents
0 commit comments