Skip to content

Commit

Permalink
Fix for missing EOS token (#408)
Browse files · Browse the repository at this point in the history
  • Loading branch information
mrwyattii authored Feb 5, 2024
1 parent add396d commit 5cc9012
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions mii/modeling/tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def decode(self, tokens: torch.Tensor) -> str:
class HFTokenizer(MIITokenizerWrapper):
def __init__(self, tokenizer: Union[str, object]) -> None:
    """Wrap a Hugging Face tokenizer.

    Args:
        tokenizer: Either an already-constructed tokenizer object, or a
            model name / local path to load with ``AutoTokenizer``.
    """
    if isinstance(tokenizer, str):
        # trust_remote_code=True lets AutoTokenizer load tokenizers whose
        # implementation ships inside the model repository.
        tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
    # Some tokenizers ship without a pad token; reuse EOS so padding works.
    tokenizer.pad_token = tokenizer.eos_token
    super().__init__(tokenizer)

Expand All @@ -51,7 +51,11 @@ def vocab_size(self) -> int:

@property
def eos_token_id(self) -> int:
return self.tokenizer.eos_token_id
eos_token_attrs = ["eod", "eos_token_id", "eos_token", "eod_id"]
for attr in eos_token_attrs:
if getattr(self.tokenizer, attr, None) is not None:
return getattr(self.tokenizer, attr)
raise ValueError(f"Tokenizer must have one of {eos_token_attrs} attributes.")

def encode(self, input: str) -> torch.Tensor:
return self.tokenizer.encode(input, return_tensors="pt").flatten()
Expand Down

0 comments on commit 5cc9012

Please sign in to comment.