FEAT: Adding 1.58bit LLMs training architecture in nanotron #180

Draft · wants to merge 4 commits into main

26 changes: 24 additions & 2 deletions examples/config_tiny_llama.py
@@ -9,6 +9,7 @@
DatasetStageArgs,
GeneralArgs,
LlamaConfig,
LlamaBitNetConfig,
LoggingArgs,
LRSchedulerArgs,
ModelArgs,
@@ -21,14 +22,35 @@
)
from nanotron.logging import human_format

model_config = LlamaConfig(
# Config for a tiny model model with 1.62M parameters
# model_config = LlamaConfig(
# # Config for a tiny model model with 1.62M parameters
# bos_token_id=1,
# eos_token_id=2,
# hidden_act="silu",
# hidden_size=16,
# initializer_range=0.02,
# intermediate_size=64,
# max_position_embeddings=256,
# num_attention_heads=4,
# num_hidden_layers=2,
# num_key_value_heads=4,
# pretraining_tp=1,
# rms_norm_eps=1e-05,
# rope_scaling=None,
# tie_word_embeddings=True,
# use_cache=True,
# vocab_size=256,
# )

model_config = LlamaBitNetConfig(
    # Config for a tiny 1.58bit model with 1.62M parameters
    bos_token_id=1,
    eos_token_id=2,
    hidden_act="silu",
    hidden_size=16,
    initializer_range=0.02,
    intermediate_size=64,
    is_bitnet_config=True,
    max_position_embeddings=256,
    num_attention_heads=4,
    num_hidden_layers=2,
2 changes: 1 addition & 1 deletion examples/config_tiny_llama.yaml
@@ -55,7 +55,7 @@ model:
hidden_size: 16
initializer_range: 0.02
intermediate_size: 64
is_llama_config: true
is_bitnet_config: true
max_position_embeddings: 256
num_attention_heads: 4
num_hidden_layers: 2
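
The one-line flip above from is_llama_config to is_bitnet_config is what the yaml/python conversion keys on: as the comment on the new field in models_config.py notes, these marker booleans are used to differentiate model configs when rebuilding them from yaml. Below is a minimal sketch of that kind of dispatch; the function names are hypothetical, not nanotron's actual loader, and it assumes LlamaBitNetConfig is importable from nanotron.config.models_config once this PR is applied.

# Sketch only: hypothetical yaml -> config-class dispatch, not nanotron's loader.
from dataclasses import fields

from nanotron.config.models_config import LlamaBitNetConfig, LlamaConfig


def pick_model_config_class(raw: dict):
    """Choose the config dataclass based on the marker flag in a yaml model_config dict."""
    if raw.get("is_bitnet_config"):
        return LlamaBitNetConfig
    if raw.get("is_llama_config"):
        return LlamaConfig
    raise ValueError("model_config yaml carries no recognized marker flag")


def build_model_config(raw: dict):
    """Rebuild the config dataclass, keeping only the keys it declares."""
    cls = pick_model_config_class(raw)
    known = {f.name for f in fields(cls)}
    return cls(**{k: v for k, v in raw.items() if k in known})
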
1 change: 0 additions & 1 deletion run_train.py
@@ -232,6 +232,5 @@ def get_args():
# Load trainer and data
trainer = DistributedTrainer(config_file)
dataloader = get_dataloader(trainer)

# Train
trainer.train(dataloader)
1 change: 0 additions & 1 deletion src/nanotron/config/config.py
@@ -211,7 +211,6 @@ def __post_init__(self):
self.dtype = torch.bfloat16
if isinstance(self.dtype, str):
self.dtype = cast_str_to_torch_dtype(self.dtype)

self.model_config._is_using_mup = isinstance(self.init_method, SpectralMupInit)

# if self.model_config.max_position_embeddings is None:
41 changes: 40 additions & 1 deletion src/nanotron/config/models_config.py
@@ -65,6 +65,45 @@
    def is_using_mup(self) -> bool:
        return self._is_using_mup

@dataclass
class LlamaBitNetConfig:
    """Configuration for a 1.58-bit (BitNet) LLaMA model.

    Be careful to keep the typing coherent, as it is used to reconstruct the model from yaml.
    """

    bos_token_id: int = 1
    eos_token_id: int = 2
    hidden_act: str = "silu"
    hidden_size: int = 4096
    initializer_range: float = 0.02
    intermediate_size: int = 11008
    is_bitnet_config: bool = True  # We use this to help differentiate models in yaml/python conversion
    max_position_embeddings: int = 2048
    num_attention_heads: int = 32
    num_hidden_layers: int = 32
    num_key_value_heads: Optional[int] = None
    pad_token_id: Optional[int] = None
    pretraining_tp: int = 1
    rms_norm_eps: float = 1e-6
    rope_scaling: Optional[dict] = None
    tie_word_embeddings: bool = False
    use_cache: bool = True
    vocab_size: int = 32000

    def __post_init__(self):
        # NOTE: users don't set self._init_method; ModelArgs sets it,
        # and then only the LlamaBitNetConfig is passed around
        self._is_using_mup: bool = False
        # self._init_method: Optional[Union[RandomInit, SpectralMupInit, ExistingCheckpointInit]] = None

        # for backward compatibility
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

    @property
    def is_using_mup(self) -> bool:
        return self._is_using_mup

@dataclass
class Starcoder2Config:
@@ -132,4 +171,4 @@ def n_inner(self):
return self.intermediate_size


NanotronConfigs = Union[LlamaConfig, Starcoder2Config, Any]
NanotronConfigs = Union[LlamaConfig, Starcoder2Config, LlamaBitNetConfig, Any]
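
As background on what the new config is meant to switch on: "1.58 bit" refers to ternary weights in {-1, 0, +1}, i.e. log2(3) ≈ 1.58 bits of information per weight. In the BitNet b1.58 recipe, each weight matrix is ternarized with an absmean scale on the forward pass while full-precision latent weights are kept for the optimizer, with gradients routed through a straight-through estimator. The sketch below only illustrates that idea; it is not the modeling code added by this PR, and absmean_ternarize / BitLinearSketch are made-up names.

# Illustrative sketch of 1.58-bit (ternary) weight quantization in the spirit of
# BitNet b1.58. Not the implementation in this PR; the names are hypothetical.
import torch


def absmean_ternarize(w: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Quantize weights to {-1, 0, +1} scaled by their mean absolute value."""
    scale = w.abs().mean().clamp(min=eps)         # per-tensor absmean scale
    w_ternary = (w / scale).round().clamp(-1, 1)  # entries in {-1, 0, +1}
    return w_ternary * scale                      # rescale for the matmul


class BitLinearSketch(torch.nn.Linear):
    """Linear layer whose forward pass uses ternarized weights.

    Full-precision latent weights stay in self.weight for the optimizer;
    the straight-through estimator routes gradients back to them.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.weight
        w_q = w + (absmean_ternarize(w) - w).detach()  # straight-through estimator
        return torch.nn.functional.linear(x, w_q, self.bias)

BitNet b1.58 additionally quantizes activations to 8 bits per token; that part is omitted here for brevity.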