Skip to content

Commit

Permalink
[Bugfix] Update config classes to match old configurations (#26)
Browse files Browse the repository at this point in the history
* Config updates for device, model and node

* Optional[float] for rope theta

---------

Co-authored-by: Amey Agrawal <[email protected]>
  • Loading branch information
anmolagarwalcp810 and AgrawalAmey authored Aug 1, 2024
1 parent 6aa32c6 commit 86e0a99
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 14 deletions.
12 changes: 6 additions & 6 deletions vidur/config/device_sku_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@ class BaseDeviceSKUConfig(BaseFixedConfig):


@dataclass
class A100DeviceSKUConfig(BaseDeviceSKUConfig):
fp16_tflops: int = 312
total_memory_gb: int = 80
class A40DeviceSKUConfig(BaseDeviceSKUConfig):
fp16_tflops: int = 150
total_memory_gb: int = 45

@staticmethod
def get_type():
return DeviceSKUType.A40


@dataclass
class A40DeviceSKUConfig(BaseDeviceSKUConfig):
fp16_tflops: int = 150
total_memory_gb: int = 45
class A100DeviceSKUConfig(BaseDeviceSKUConfig):
fp16_tflops: int = 312
total_memory_gb: int = 80

@staticmethod
def get_type():
Expand Down
35 changes: 29 additions & 6 deletions vidur/config/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class BaseModelConfig(BaseFixedConfig):
post_attn_norm: bool
vocab_size: int
is_neox_style: Optional[bool] = True
rope_theta: Optional[int] = None
rope_theta: Optional[float] = None
rope_scaling: Optional[Dict[str, Any]] = None
partial_rotary_factor: float = 1.0
no_tensor_parallel: bool = False
Expand All @@ -41,7 +41,7 @@ class Llama2ModelConfig(BaseModelConfig):
post_attn_norm: bool = True
vocab_size: int = 32768
is_neox_style: Optional[bool] = True
rope_theta: Optional[int] = 10000.0
rope_theta: Optional[float] = 10000
rope_scaling: Optional[Dict[str, Any]] = None
partial_rotary_factor: float = 1.0
no_tensor_parallel: bool = False
Expand All @@ -58,6 +58,7 @@ class CodeLlama34BModelConfig(Llama2ModelConfig):
num_kv_heads: int = 8
embedding_dim: int = 8192
mlp_hidden_dim: int = 22016
rope_theta: Optional[float] = 1000000

@staticmethod
def get_name():
Expand All @@ -71,6 +72,7 @@ class Llama2_7BModelConfig(Llama2ModelConfig):
num_kv_heads: int = 32
embedding_dim: int = 4096
mlp_hidden_dim: int = 11008
max_position_embeddings: int = 4096

@staticmethod
def get_name():
Expand All @@ -84,6 +86,7 @@ class Llama2_70BModelConfig(Llama2ModelConfig):
num_kv_heads: int = 8
embedding_dim: int = 8192
mlp_hidden_dim: int = 28672
max_position_embeddings: int = 4096

@staticmethod
def get_name():
Expand All @@ -98,7 +101,7 @@ class Llama3_8BModelConfig(Llama2ModelConfig):
embedding_dim: int = 4096
mlp_hidden_dim: int = 14336
max_position_embeddings: int = 4096
rope_theta: Optional[int] = 500000.0
rope_theta: Optional[float] = 500000
vocab_size: int = 128256

@staticmethod
Expand All @@ -114,14 +117,33 @@ class Llama3_70BModelConfig(Llama2ModelConfig):
embedding_dim: int = 8192
mlp_hidden_dim: int = 28672
max_position_embeddings: int = 8192
rope_theta: Optional[int] = 500000.0
rope_theta: Optional[float] = 500000
vocab_size: int = 128256

@staticmethod
def get_name():
return "meta-llama/Meta-Llama-3-70B"


@dataclass
class InternLMModelConfig(Llama2ModelConfig):
max_position_embeddings: int = 4096
vocab_size: int = 103168


@dataclass
class InternLM_20BModelConfig(InternLMModelConfig):
num_layers: int = 60
num_q_heads: int = 40
num_kv_heads: int = 40
embedding_dim: int = 5120
mlp_hidden_dim: int = 13824

@staticmethod
def get_name():
return "internlm/internlm-20b"


@dataclass
class InternLM2ModelConfig(Llama2ModelConfig):
max_position_embeddings: int = 32768
Expand All @@ -135,6 +157,7 @@ class InternLM2_20BModelConfig(InternLM2ModelConfig):
num_kv_heads: int = 8
embedding_dim: int = 6144
mlp_hidden_dim: int = 16384
rope_theta: Optional[float] = 1000000

@staticmethod
def get_name():
Expand All @@ -157,10 +180,9 @@ class Phi2ModelConfig(Llama2ModelConfig):
post_attn_norm: bool = False
vocab_size: int = 51200
rope_scaling: Optional[Dict[str, Any]] = None
rope_theta: Optional[int] = 10000.0
rope_theta: Optional[float] = 10000
partial_rotary_factor: float = 0.4
no_tensor_parallel: bool = True
is_neox_style: bool = True

@staticmethod
def get_name():
Expand All @@ -185,6 +207,7 @@ class Qwen72BModelConfig(QwenModelConfig):
num_kv_heads: int = 64
embedding_dim: int = 8192
mlp_hidden_dim: int = 24576
rope_theta: Optional[float] = 1000000

@staticmethod
def get_name():
Expand Down
4 changes: 2 additions & 2 deletions vidur/config/node_sku_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def get_type():
@dataclass
class A100PairwiseNvlinkNodeSKUConfig(BaseNodeSKUConfig):
device_sku_type: DeviceSKUType = DeviceSKUType.A100
num_devices_per_node: int = 8
num_devices_per_node: int = 4

@staticmethod
def get_type():
Expand All @@ -35,7 +35,7 @@ def get_type():
@dataclass
class H100PairwiseNvlinkNodeSKUConfig(BaseNodeSKUConfig):
device_sku_type: DeviceSKUType = DeviceSKUType.H100
num_devices_per_node: int = 8
num_devices_per_node: int = 4

@staticmethod
def get_type():
Expand Down

0 comments on commit 86e0a99

Please sign in to comment.