diff --git a/ernie/__init__.py b/ernie/__init__.py
index 3344544..a484e73 100644
--- a/ernie/__init__.py
+++ b/ernie/__init__.py
@@ -5,7 +5,7 @@
 from tensorflow.python.client import device_lib
 import logging
 
-__version__ = '1.2307.0'
+__version__ = '1.2308.0'
 
 logging.getLogger().setLevel(logging.WARNING)
 logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
diff --git a/ernie/ernie.py b/ernie/ernie.py
index e4a7923..eb801db 100644
--- a/ernie/ernie.py
+++ b/ernie/ernie.py
@@ -266,9 +266,17 @@ def _predict_batch(
                 add_special_tokens=True,
                 max_length=self._tokenizer.model_max_length,
             )
+
+            input_ids = features['input_ids']
+            if 'token_type_ids' in features:
+                token_type_ids = features['token_type_ids']
+            else:
+                token_type_ids = [0] * len(input_ids)  # fill with zeros
+
             input_ids, _, attention_mask = (
-                features['input_ids'], features['token_type_ids'],
-                features['attention_mask']
+                input_ids,
+                token_type_ids,
+                features['attention_mask'],
             )
             input_ids = self._list_to_padded_array(features['input_ids'])
diff --git a/ernie/helper.py b/ernie/helper.py
index 754aaf9..52b9e06 100644
--- a/ernie/helper.py
+++ b/ernie/helper.py
@@ -15,17 +15,22 @@ def get_features(tokenizer, sentences, labels):
             add_special_tokens=True,
             max_length=tokenizer.model_max_length,
         )
-        input_ids, token_type_ids = (
-            inputs['input_ids'],
-            inputs['token_type_ids'],
-        )
+
+        input_ids = inputs['input_ids']
+        if 'token_type_ids' in inputs:
+            token_type_ids = inputs['token_type_ids']
+        else:
+            token_type_ids = [0] * len(input_ids)  # fill with zeros
+
         padding_length = tokenizer.model_max_length - len(input_ids)
         if tokenizer.padding_side == 'right':
             attention_mask = [1] * len(input_ids) + [0] * padding_length
             input_ids = input_ids + [tokenizer.pad_token_id] * padding_length
-            token_type_ids = token_type_ids + \
-                [tokenizer.pad_token_type_id] * padding_length
+            token_type_ids = (
+                token_type_ids + [tokenizer.pad_token_type_id] * padding_length
+            )
+
         else:
             attention_mask = [0] * padding_length + [1] * len(input_ids)
             input_ids = [tokenizer.pad_token_id] * padding_length + input_ids
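
Note on why the fallback is needed: tokenizers for models trained without segment embeddings (DistilBERT, RoBERTa, and similar) do not include a 'token_type_ids' entry in the output of encode_plus(), so the previous unconditional lookup raised a KeyError. Below is a minimal sketch of the difference, assuming the transformers package with the bert-base-uncased and distilbert-base-uncased checkpoints (chosen here purely for illustration, not taken from the patch):

    from transformers import AutoTokenizer

    # BERT-style tokenizers emit segment ids; DistilBERT-style ones do not.
    bert = AutoTokenizer.from_pretrained('bert-base-uncased')
    distilbert = AutoTokenizer.from_pretrained('distilbert-base-uncased')

    bert_inputs = bert.encode_plus('hello world', add_special_tokens=True)
    distil_inputs = distilbert.encode_plus('hello world', add_special_tokens=True)

    print('token_type_ids' in bert_inputs)    # True
    print('token_type_ids' in distil_inputs)  # False: indexing it raises KeyError

    # The patched code zero-fills instead, which matches what the segment ids
    # of a single-sentence input would contain anyway:
    input_ids = distil_inputs['input_ids']
    token_type_ids = distil_inputs.get('token_type_ids', [0] * len(input_ids))

Zero-filling is safe here because both patched call sites pass a single sentence, and for single-segment inputs every token belongs to segment 0.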