Skip to content

Commit

Permalink
Fill token_type_ids with zeros
Browse files Browse the repository at this point in the history
  • Loading branch information
brunneis committed Aug 2, 2023
1 parent 5aef5ad commit ad93781
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 9 deletions.
2 changes: 1 addition & 1 deletion ernie/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from tensorflow.python.client import device_lib
import logging

__version__ = '1.2307.0'
__version__ = '1.2308.0'

logging.getLogger().setLevel(logging.WARNING)
logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
Expand Down
12 changes: 10 additions & 2 deletions ernie/ernie.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,9 +266,17 @@ def _predict_batch(
add_special_tokens=True,
max_length=self._tokenizer.model_max_length,
)

input_ids = features['input_ids']
if 'token_type_ids' in features:
token_type_ids = features['token_type_ids']
else:
token_type_ids = [0] * len(input_ids) # fill with zeros

input_ids, _, attention_mask = (
features['input_ids'], features['token_type_ids'],
features['attention_mask']
input_ids,
token_type_ids,
features['attention_mask'],
)

input_ids = self._list_to_padded_array(features['input_ids'])
Expand Down
17 changes: 11 additions & 6 deletions ernie/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,22 @@ def get_features(tokenizer, sentences, labels):
add_special_tokens=True,
max_length=tokenizer.model_max_length,
)
input_ids, token_type_ids = (
inputs['input_ids'],
inputs['token_type_ids'],
)

input_ids = inputs['input_ids']
if 'token_type_ids' in inputs:
token_type_ids = inputs['token_type_ids']
else:
token_type_ids = [0] * len(input_ids) # fill with zeros

padding_length = tokenizer.model_max_length - len(input_ids)

if tokenizer.padding_side == 'right':
attention_mask = [1] * len(input_ids) + [0] * padding_length
input_ids = input_ids + [tokenizer.pad_token_id] * padding_length
token_type_ids = token_type_ids + \
[tokenizer.pad_token_type_id] * padding_length
token_type_ids = (
token_type_ids + [tokenizer.pad_token_type_id] * padding_length
)

else:
attention_mask = [0] * padding_length + [1] * len(input_ids)
input_ids = [tokenizer.pad_token_id] * padding_length + input_ids
Expand Down

0 comments on commit ad93781

Please sign in to comment.