diff --git a/vidur/config/device_sku_config.py b/vidur/config/device_sku_config.py
index a92646f..8ac9bf5 100644
--- a/vidur/config/device_sku_config.py
+++ b/vidur/config/device_sku_config.py
@@ -14,9 +14,9 @@ class BaseDeviceSKUConfig(BaseFixedConfig):
 
 
 @dataclass
-class A100DeviceSKUConfig(BaseDeviceSKUConfig):
-    fp16_tflops: int = 312
-    total_memory_gb: int = 80
+class A40DeviceSKUConfig(BaseDeviceSKUConfig):
+    fp16_tflops: int = 150
+    total_memory_gb: int = 45
 
     @staticmethod
     def get_type():
@@ -24,9 +24,9 @@ def get_type():
 
 
 @dataclass
-class A40DeviceSKUConfig(BaseDeviceSKUConfig):
-    fp16_tflops: int = 150
-    total_memory_gb: int = 45
+class A100DeviceSKUConfig(BaseDeviceSKUConfig):
+    fp16_tflops: int = 312
+    total_memory_gb: int = 80
 
     @staticmethod
     def get_type():
diff --git a/vidur/config/model_config.py b/vidur/config/model_config.py
index 0057d78..722299b 100644
--- a/vidur/config/model_config.py
+++ b/vidur/config/model_config.py
@@ -24,7 +24,7 @@ class BaseModelConfig(BaseFixedConfig):
     post_attn_norm: bool
     vocab_size: int
     is_neox_style: Optional[bool] = True
-    rope_theta: Optional[int] = None
+    rope_theta: Optional[float] = None
     rope_scaling: Optional[Dict[str, Any]] = None
     partial_rotary_factor: float = 1.0
     no_tensor_parallel: bool = False
@@ -41,7 +41,7 @@ class Llama2ModelConfig(BaseModelConfig):
     post_attn_norm: bool = True
     vocab_size: int = 32768
     is_neox_style: Optional[bool] = True
-    rope_theta: Optional[int] = 10000.0
+    rope_theta: Optional[float] = 10000
     rope_scaling: Optional[Dict[str, Any]] = None
     partial_rotary_factor: float = 1.0
     no_tensor_parallel: bool = False
@@ -58,6 +58,7 @@ class CodeLlama34BModelConfig(Llama2ModelConfig):
     num_kv_heads: int = 8
     embedding_dim: int = 8192
     mlp_hidden_dim: int = 22016
+    rope_theta: Optional[float] = 1000000
 
     @staticmethod
     def get_name():
@@ -71,6 +72,7 @@ class Llama2_7BModelConfig(Llama2ModelConfig):
     num_kv_heads: int = 32
     embedding_dim: int = 4096
     mlp_hidden_dim: int = 11008
+    max_position_embeddings: int = 4096
 
     @staticmethod
     def get_name():
@@ -84,6 +86,7 @@ class Llama2_70BModelConfig(Llama2ModelConfig):
     num_kv_heads: int = 8
     embedding_dim: int = 8192
     mlp_hidden_dim: int = 28672
+    max_position_embeddings: int = 4096
 
     @staticmethod
     def get_name():
@@ -98,7 +101,7 @@ class Llama3_8BModelConfig(Llama2ModelConfig):
     embedding_dim: int = 4096
     mlp_hidden_dim: int = 14336
     max_position_embeddings: int = 4096
-    rope_theta: Optional[int] = 500000.0
+    rope_theta: Optional[float] = 500000
     vocab_size: int = 128256
 
     @staticmethod
@@ -114,7 +117,7 @@ class Llama3_70BModelConfig(Llama2ModelConfig):
     embedding_dim: int = 8192
     mlp_hidden_dim: int = 28672
     max_position_embeddings: int = 8192
-    rope_theta: Optional[int] = 500000.0
+    rope_theta: Optional[float] = 500000
     vocab_size: int = 128256
 
     @staticmethod
@@ -122,6 +125,25 @@ def get_name():
         return "meta-llama/Meta-Llama-3-70B"
 
 
+@dataclass
+class InternLMModelConfig(Llama2ModelConfig):
+    max_position_embeddings: int = 4096
+    vocab_size: int = 103168
+
+
+@dataclass
+class InternLM_20BModelConfig(InternLMModelConfig):
+    num_layers: int = 60
+    num_q_heads: int = 40
+    num_kv_heads: int = 40
+    embedding_dim: int = 5120
+    mlp_hidden_dim: int = 13824
+
+    @staticmethod
+    def get_name():
+        return "internlm/internlm-20b"
+
+
 @dataclass
 class InternLM2ModelConfig(Llama2ModelConfig):
     max_position_embeddings: int = 32768
@@ -135,6 +157,7 @@ class InternLM2_20BModelConfig(InternLM2ModelConfig):
     num_kv_heads: int = 8
     embedding_dim: int = 6144
     mlp_hidden_dim: int = 16384
+    rope_theta: Optional[float] = 1000000
 
     @staticmethod
     def get_name():
@@ -157,10 +180,9 @@ class Phi2ModelConfig(Llama2ModelConfig):
     post_attn_norm: bool = False
     vocab_size: int = 51200
     rope_scaling: Optional[Dict[str, Any]] = None
-    rope_theta: Optional[int] = 10000.0
+    rope_theta: Optional[float] = 10000
     partial_rotary_factor: float = 0.4
     no_tensor_parallel: bool = True
-    is_neox_style: bool = True
 
     @staticmethod
     def get_name():
@@ -185,6 +207,7 @@ class Qwen72BModelConfig(QwenModelConfig):
     num_kv_heads: int = 64
     embedding_dim: int = 8192
     mlp_hidden_dim: int = 24576
+    rope_theta: Optional[float] = 1000000
 
     @staticmethod
     def get_name():
diff --git a/vidur/config/node_sku_config.py b/vidur/config/node_sku_config.py
index 34eb805..ce2271f 100644
--- a/vidur/config/node_sku_config.py
+++ b/vidur/config/node_sku_config.py
@@ -25,7 +25,7 @@ def get_type():
 @dataclass
 class A100PairwiseNvlinkNodeSKUConfig(BaseNodeSKUConfig):
     device_sku_type: DeviceSKUType = DeviceSKUType.A100
-    num_devices_per_node: int = 8
+    num_devices_per_node: int = 4
 
     @staticmethod
     def get_type():
@@ -35,7 +35,7 @@ def get_type():
 @dataclass
 class H100PairwiseNvlinkNodeSKUConfig(BaseNodeSKUConfig):
     device_sku_type: DeviceSKUType = DeviceSKUType.H100
-    num_devices_per_node: int = 8
+    num_devices_per_node: int = 4
 
     @staticmethod
     def get_type():
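
Note (reviewer addition, not part of the patch): the InternLM_20BModelConfig entry added above can be smoke-checked with a sketch like the one below. It uses only class names, field names, and values that appear in the diff, and it assumes that, as with the other leaf configs in model_config.py, every inherited field carries a default so the dataclass can be instantiated without arguments.

    # Minimal smoke check for the InternLM-20B config added in this patch
    # (reviewer sketch; all expected values are taken from the diff).
    from vidur.config.model_config import InternLM_20BModelConfig

    cfg = InternLM_20BModelConfig()
    assert cfg.get_name() == "internlm/internlm-20b"
    assert cfg.num_layers == 60 and cfg.embedding_dim == 5120
    # Inherited from InternLMModelConfig:
    assert cfg.max_position_embeddings == 4096
    assert cfg.vocab_size == 103168
    # Inherited from Llama2ModelConfig, which this patch retypes from
    # Optional[int] to Optional[float]:
    assert cfg.rope_theta == 10000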