31 changes: 29 additions & 2 deletions torchtitan/components/tokenizer.py
@@ -14,6 +14,7 @@
from tokenizers import AddedToken, Tokenizer
from torchtitan.config import JobConfig
from torchtitan.tools.logging import logger
from transformers import AutoTokenizer
from typing_extensions import override


@@ -59,6 +60,8 @@ def __init__(
        self.eos_id = None
        self.bos_token = None
        self.eos_token = None
        self.pad_id = None  # only used for SFT
        self.pad_token = None  # only used for SFT

        # Load the underlying tokenizer
        self.tokenizer = self._load_tokenizer_from_path(tokenizer_path)
@@ -68,6 +71,10 @@ def __init__(
            os.path.join(tokenizer_path, "tokenizer_config.json")
        )

        self.backup_hf_tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_path, use_fast=True
        )

        # Infer special tokens and whether BOS/EOS should be added
        self._infer_special_tokens()
        self._infer_should_add_bos_eos()
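
The constructor now keeps a full Hugging Face tokenizer next to the lower-level `tokenizers.Tokenizer`, since chat templating is only exposed through the `transformers` API. A minimal standalone sketch of this pattern (the path is hypothetical):

```python
# Sketch only: keep a transformers tokenizer as a backup so that
# apply_chat_template() is available, while encoding continues to go
# through the lower-level tokenizers.Tokenizer.
from transformers import AutoTokenizer

backup_hf_tokenizer = AutoTokenizer.from_pretrained(
    "./assets/tokenizer",  # hypothetical path containing tokenizer_config.json
    use_fast=True,
)
```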
@@ -214,7 +221,11 @@ def _process_special_token(
            if self.config
            else None
        )
        config_pad_token = (
            self._get_token_from_config(self.config, "pad_token")
            if self.config
            else None
        )
        # Store BOS/EOS tokens as class attributes if they match
        if token_str == config_bos_token:
            self.bos_token = token_str
@@ -230,7 +241,13 @@
                if token_id is not None
                else self.tokenizer.token_to_id(token_str)
            )
        elif token_str == config_pad_token:
            self.pad_token = token_str
            self.pad_id = (
                token_id
                if token_id is not None
                else self.tokenizer.token_to_id(token_str)
            )
        # Create AddedToken object based on config format
        if isinstance(token_config, dict):
            if token_config.get("__type") == "AddedToken" or "content" in token_config:
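
For context, special tokens in `tokenizer_config.json` may be stored either as plain strings or as AddedToken-style dicts, which is why the branch above checks `__type`/`content`. A hypothetical helper normalizing both forms:

```python
# Hypothetical sketch, not part of the PR: normalize the two shapes a
# special-token entry can take in tokenizer_config.json.
def token_content(entry) -> str | None:
    if entry is None:
        return None
    if isinstance(entry, str):  # e.g. "pad_token": "<pad>"
        return entry
    if isinstance(entry, dict):  # e.g. {"__type": "AddedToken", "content": "<pad>"}
        return entry.get("content")
    return None

assert token_content("<pad>") == "<pad>"
assert token_content({"__type": "AddedToken", "content": "<pad>"}) == "<pad>"
```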
@@ -301,6 +318,12 @@ def _infer_special_tokens(self):
            self.bos_id = self.tokenizer.token_to_id(self.bos_token)
        if self.eos_token:
            self.eos_id = self.tokenizer.token_to_id(self.eos_token)
        # For SFT: resolve the pad id, falling back to EOS when the config defines no pad token
        if self.pad_token is not None:
            self.pad_id = self.tokenizer.token_to_id(self.pad_token)
        else:
            self.pad_id = self.eos_id
            self.pad_token = self.eos_token
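
With this fallback, `pad_id` always resolves to a usable id: the configured pad token's id when one exists, otherwise `eos_id` (common for Llama-style tokenizers that define no pad token). A hypothetical illustration of right-padding a batch with the resolved id:

```python
# Illustration only: right-pad a token batch with pad_id, which equals
# eos_id whenever the tokenizer defines no dedicated pad token.
def pad_batch(seqs: list[list[int]], pad_id: int) -> list[list[int]]:
    max_len = max(len(s) for s in seqs)
    return [s + [pad_id] * (max_len - len(s)) for s in seqs]

print(pad_batch([[1, 5, 7], [1, 5]], pad_id=2))  # [[1, 5, 7], [1, 5, 2]]
```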

def _infer_should_add_bos_eos(self):
"""
@@ -409,6 +432,10 @@ def id_to_token(self, token_id: int) -> Optional[str]:
"""Convert ID to token."""
return self.tokenizer.id_to_token(token_id)

def apply_chat_template(self, conversation, **kwargs):

return self.backup_hf_tokenizer.apply_chat_template(conversation, **kwargs)
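
Since kwargs pass straight through, callers get the standard `transformers` chat-template interface. A hedged usage sketch, where `tokenizer` stands in for an instance of this class:

```python
# Usage sketch; the kwargs shown are standard transformers options.
conversation = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4."},
]
text = tokenizer.apply_chat_template(
    conversation,
    tokenize=False,               # return the templated string, not token ids
    add_generation_prompt=False,  # do not append a trailing assistant header
)
```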


def build_hf_tokenizer(
    job_config: JobConfig,
26 changes: 26 additions & 0 deletions torchtitan/config/job_config.py
@@ -198,11 +198,36 @@ class LRScheduler:
"""


@dataclass
class SFTDataConfig:
split: str = "train"
dataset_subset: str | None = None
messages_key: str = "messages"
prompt_key: str = "question"
response_key: str = "answer"
tools_key: str = "tools"
thinking_key: str = "enable_thinking"
always_thinking: bool = False
is_multiturn: bool = False
apply_chat_template: bool = False
greedy_packing: bool = True
pad_mode: Literal["right", "no_padding"] = "right"
truncation: Literal["left", "right", "error"] = "right"
chat_template_kwargs: dict = field(default_factory=dict)
ignore_input_ids_mismatch: bool = False
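
A hedged sketch of driving the new config from Python; the field values are illustrative, and a TOML job file would presumably set the same fields under an `[sft_data_config]` section:

```python
# Illustrative only: enable SFT mode and supply an SFT data config.
from torchtitan.config.job_config import JobConfig, SFTDataConfig

cfg = JobConfig()
cfg.training.running_sft_training = True  # dataset config now comes from sft_data_config
cfg.sft_data_config = SFTDataConfig(
    messages_key="messages",
    apply_chat_template=True,
    pad_mode="right",
    truncation="right",
)
```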


@dataclass
class Training:
    dataset: str = "c4_test"
    """Dataset to use"""

    running_sft_training: bool = False
    """
    If True, run SFT training; the dataset configuration is then overridden
    by sft_data_config.
    """

    dataset_path: str | None = None
    """
    Path to the dataset in the file system. If provided, data will be
@@ -937,6 +962,7 @@ class JobConfig:
    profiling: Profiling = field(default_factory=Profiling)
    metrics: Metrics = field(default_factory=Metrics)
    model: Model = field(default_factory=Model)
    sft_data_config: SFTDataConfig = field(default_factory=SFTDataConfig)
    optimizer: Optimizer = field(default_factory=Optimizer)
    lr_scheduler: LRScheduler = field(default_factory=LRScheduler)
    training: Training = field(default_factory=Training)