From a1533fc0cdf5156928d57177ae5a8f3990eb30f0 Mon Sep 17 00:00:00 2001 From: Bobby Chen Date: Mon, 25 Nov 2024 16:17:23 -0800 Subject: [PATCH 1/8] Support mamba_hybrid export --- megatron/core/export/model_type.py | 2 +- .../default_conversion_dict.py | 2 + .../mamba_hybrid_model.py | 57 +++++++++++++++++++ .../core/export/trtllm/trt_model_config.py | 1 + megatron/core/export/trtllm/trt_model_type.py | 1 + megatron/core/export/trtllm/trtllm_helper.py | 16 ++++++ megatron/core/export/trtllm/trtllm_layers.py | 10 ++++ ...e_device_trtllm_model_weights_converter.py | 37 +++++++++++- 8 files changed, 123 insertions(+), 3 deletions(-) create mode 100644 megatron/core/export/trtllm/model_to_trllm_mapping/mamba_hybrid_model.py diff --git a/megatron/core/export/model_type.py b/megatron/core/export/model_type.py index 6a33d6440e..32f53f56d2 100644 --- a/megatron/core/export/model_type.py +++ b/megatron/core/export/model_type.py @@ -3,5 +3,5 @@ from enum import Enum ModelType = Enum( - 'ModelType', ["gpt", "gptnext", "llama", "falcon", "starcoder", "mixtral", "gemma"] + 'ModelType', ["gpt", "gptnext", "llama", "falcon", "starcoder", "mixtral", "gemma", "mamba_hybrid"] ) diff --git a/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py b/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py index cad9315034..d18fabaabd 100644 --- a/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +++ b/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py @@ -7,6 +7,7 @@ from megatron.core.export.trtllm.model_to_trllm_mapping.gpt_next_model import GPT_NEXT_DICT from megatron.core.export.trtllm.model_to_trllm_mapping.llama_model import LLAMA_DICT from megatron.core.export.trtllm.model_to_trllm_mapping.starcoder_model import STARCODER_DICT +from megatron.core.export.trtllm.model_to_trllm_mapping.mamba_hybrid_model import MAMBA_HYBRID_DICT DEFAULT_CONVERSION_DICT = { ModelType.llama: LLAMA_DICT, @@ -15,4 +16,5 @@ ModelType.starcoder: STARCODER_DICT, ModelType.gpt: GPT_DICT, ModelType.gptnext: GPT_NEXT_DICT, + ModelType.mamba_hybrid: MAMBA_HYBRID_DICT, } diff --git a/megatron/core/export/trtllm/model_to_trllm_mapping/mamba_hybrid_model.py b/megatron/core/export/trtllm/model_to_trllm_mapping/mamba_hybrid_model.py new file mode 100644 index 0000000000..e3bbd028f2 --- /dev/null +++ b/megatron/core/export/trtllm/model_to_trllm_mapping/mamba_hybrid_model.py @@ -0,0 +1,57 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
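+"""Megatron-to-TRTLLM layer-name mapping for mamba hybrid models.
+
+Megatron checkpoints store the Mamba2 mixer's in_proj as separate
+z/x/B/C/dt tensors and the conv1d weight/bias as separate x/B/C tensors;
+mamba_preprocess_weight below concatenates each group back into the single
+fused tensor that TRTLLM expects.
+"""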
+ +import torch +from megatron.core.export.trtllm.trtllm_layers import TRTLLMLayers + +def mamba_preprocess_weight(model_state_dict: dict): + for k in list(model_state_dict.keys()): + if 'mixer.in_proj.weight' in k: + if k[-1] == 'z': + prefix = k[:-2] + z = model_state_dict.pop(k) + x = model_state_dict.pop(f"{prefix}.x") + B = model_state_dict.pop(f"{prefix}.B") + C = model_state_dict.pop(f"{prefix}.C") + dt = model_state_dict.pop(f"{prefix}.dt") + model_state_dict[prefix] = torch.concatenate( + [z, x, B, C, dt], dim=0 + ) + elif 'conv1d' in k: + if k[-1] == 'x': + prefix = k[:-2] + x = model_state_dict.pop(k) + B = model_state_dict.pop(f"{prefix}.B") + C = model_state_dict.pop(f"{prefix}.C") + model_state_dict[prefix] = torch.concatenate( + [x, B, C], dim=0 + ) + + +MAMBA_HYBRID_DICT = { + # INPUT + 'embedding.word_embeddings.weight': TRTLLMLayers.vocab_embedding, + # ATTENTION + 'decoder.layers.self_attention.linear_qkv.weight': TRTLLMLayers.attention_qkv_weight, + 'decoder.layers.self_attention.linear_qkv.layer_norm_weight': TRTLLMLayers.input_layernorm_weight, + 'decoder.layers.self_attention.linear_proj.weight': TRTLLMLayers.attention_dense_weight, + # MLP + 'decoder.layers.mlp.linear_fc1.weight': TRTLLMLayers.mlp_fc_weight, + 'decoder.layers.mlp.linear_fc2.weight': TRTLLMLayers.mlp_projection_weight, + 'decoder.layers.mlp.linear_fc1.layer_norm_weight': TRTLLMLayers.input_layernorm_weight, + # Mixer + 'decoder.layers.mixer.dt_bias': TRTLLMLayers.mixer_dt_bias, + 'decoder.layers.mixer.A_log': TRTLLMLayers.mixer_A_log, + 'decoder.layers.mixer.D': TRTLLMLayers.mixer_D, + 'decoder.layers.mixer.in_proj.layer_norm_weight': TRTLLMLayers.input_layernorm_weight, + 'decoder.layers.mixer.in_proj.weight': TRTLLMLayers.mixer_in_proj_weight, + 'decoder.layers.mixer.conv1d.weight':TRTLLMLayers.mixer_conv_weight, + 'decoder.layers.mixer.conv1d.bias': TRTLLMLayers.mixer_conv_bias, + 'decoder.layers.mixer.out_proj.weight': TRTLLMLayers.mixer_out_proj_weight, + 'decoder.layers.mixer.norm.weight': TRTLLMLayers.mixer_norm_weight, + # FINAL LAYER NORM + 'decoder.final_norm.weight': TRTLLMLayers.final_layernorm_weight, + # OUTPUT LAYER + 'output_layer.weight': TRTLLMLayers.lm_head, + + 'preprocess_weight': mamba_preprocess_weight, +} diff --git a/megatron/core/export/trtllm/trt_model_config.py b/megatron/core/export/trtllm/trt_model_config.py index 2ed09398c2..d7bd16ea6e 100644 --- a/megatron/core/export/trtllm/trt_model_config.py +++ b/megatron/core/export/trtllm/trt_model_config.py @@ -12,4 +12,5 @@ ModelType.llama: tensorrt_llm.models.llama.config.LLaMAConfig, ModelType.gemma: tensorrt_llm.models.GemmaConfig, ModelType.falcon: tensorrt_llm.models.falcon.config.FalconConfig, + ModelType.mamba_hybrid: tensorrt_llm.models.mamba_hybrid.config.MambaHybridConfig, } diff --git a/megatron/core/export/trtllm/trt_model_type.py b/megatron/core/export/trtllm/trt_model_type.py index f45ff1786e..a1b220dc67 100644 --- a/megatron/core/export/trtllm/trt_model_type.py +++ b/megatron/core/export/trtllm/trt_model_type.py @@ -10,4 +10,5 @@ ModelType.llama: 'LlamaForCausalLM', ModelType.gemma: 'GemmaForCausalLM', ModelType.falcon: 'FalconForCausalLM', + ModelType.mamba_hybrid: 'MambaHybridForCausalLM', } diff --git a/megatron/core/export/trtllm/trtllm_helper.py b/megatron/core/export/trtllm/trtllm_helper.py index d8bef18b33..68cf027c15 100644 --- a/megatron/core/export/trtllm/trtllm_helper.py +++ b/megatron/core/export/trtllm/trtllm_helper.py @@ -42,6 +42,7 @@ def __init__( seq_len_interpolation_factor: float = None, 
moe_renorm_mode=None, share_embeddings_and_output_weights=False, + hybrid_override_pattern: str=None, ): """Constructor for the TRTLLMHelper @@ -72,6 +73,7 @@ def __init__( assert position_embedding_type in [ 'learned_absolute', 'rope', + 'none', ], f"Position embedding type should be one of learned_absolute, rope. You entered {position_embedding_type}" self.position_embedding_type = position_embedding_type self.max_position_embeddings = max_position_embeddings @@ -83,6 +85,7 @@ def __init__( self.seq_len_interpolation_factor = seq_len_interpolation_factor self.moe_renorm_mode = moe_renorm_mode self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + self.hybrid_override_pattern = hybrid_override_pattern def _get_trtllm_config( self, @@ -162,6 +165,19 @@ def _get_trtllm_config( False if self.transformer_config.num_layers == 32 else True ) config["parallel_attention"] = True + elif self.model_type == ModelType.mamba_hybrid: + config["mamba_version"] = "Mamba2" + config["rnn_hidden_size"] = 2 * self.transformer_config.hidden_size + config["state_size"] = 128 + config["conv_kernel"] = 4 + config["rnn_head_size"] = 64 + config["ngroups"] = 8 + config["chunk_size"] = 128 + config["rnn_conv_dim_size"] = (config["rnn_hidden_size"] + 2 * config["ngroups"] + * config["state_size"]) + config["use_bias"] = config["bias"] + config["hybrid_override_pattern"] = self.hybrid_override_pattern + config["ssm_rmsnorm"] = True if self.seq_len_interpolation_factor is not None: config["rotary_scaling"] = { diff --git a/megatron/core/export/trtllm/trtllm_layers.py b/megatron/core/export/trtllm/trtllm_layers.py index 0cf805dcb6..097055cf05 100644 --- a/megatron/core/export/trtllm/trtllm_layers.py +++ b/megatron/core/export/trtllm/trtllm_layers.py @@ -38,6 +38,16 @@ class TRTLLMLayers(Enum): mlp_projection_weight = 'transformer.layers.mlp.proj.weight' mlp_projection_bias = 'transformer.layers.mlp.proj.bias' + # mamba layers + mixer_dt_bias = 'transformer.layers.layer.dt_bias' + mixer_A_log = 'transformer.layers.layer.A' + mixer_D = 'transformer.layers.layer.D' + mixer_in_proj_weight = 'transformer.layers.layer.in_proj.weight' + mixer_out_proj_weight = 'transformer.layers.layer.out_proj.weight' + mixer_conv_weight = 'transformer.layers.layer.conv1d.weight' + mixer_conv_bias = 'transformer.layers.layer.conv1d.bias' + mixer_norm_weight = 'transformer.layers.layer.norm.weight' + # mixture of expert layers mlp_router_weight = 'transformer.layers.mlp.router.weight' mlp_fc_weight_mixture_of_experts = 'transformer.layers.mlp.fc.weight.expert' diff --git a/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py b/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py index c7a98972d2..06f9242b4e 100644 --- a/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +++ b/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py @@ -111,6 +111,11 @@ def _add_to_trtllm_model_weights(val: torch.Tensor, layer_name: str, split_type= self.trtllm_model_weights[f'{layer_name}.{split_num}.bin'] = ( split_val.to(self.storage_type).detach().contiguous() ) + elif split_type == 'conv': + val = val.unsqueeze(-1) + self.trtllm_model_weights[layer_name] = ( + val.to(self.storage_type).detach().contiguous() + ) else: if val.ndim >= 2: val = torch.transpose(val.reshape(val.shape[0], -1), 1, 0) @@ -130,6 +135,12 @@ def 
_add_to_trtllm_model_weights(val: torch.Tensor, layer_name: str, split_type= or layer_name.endswith(suffix(TRTLLMLayers.attention_dense_bias)) or layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_bias)) or layer_name.endswith(suffix(TRTLLMLayers.mlp_router_weight)) + or layer_name.endswith(suffix(TRTLLMLayers.mixer_dt_bias)) + or layer_name.endswith(suffix(TRTLLMLayers.mixer_D)) + or layer_name.endswith(suffix(TRTLLMLayers.mixer_conv_bias)) + or layer_name.endswith(suffix(TRTLLMLayers.mixer_in_proj_weight)) + or layer_name.endswith(suffix(TRTLLMLayers.mixer_norm_weight)) + or layer_name.endswith(suffix(TRTLLMLayers.mixer_out_proj_weight)) ): # Same as layernorm1p in NeMo if ( @@ -261,6 +272,13 @@ def _add_to_trtllm_model_weights(val: torch.Tensor, layer_name: str, split_type= _add_to_trtllm_model_weights( val=split_vals, layer_name=layer_name, split_type='expert_split' ) + elif layer_name.endswith(suffix(TRTLLMLayers.mixer_A_log)): + val = -torch.exp(val.float()) + _add_to_trtllm_model_weights(val=val, layer_name=layer_name, split_type=None) + elif layer_name.endswith(suffix(TRTLLMLayers.mixer_conv_weight)): + _add_to_trtllm_model_weights( + val=val, layer_name=layer_name, split_type='conv' + ) else: raise ValueError(f"{layer_name} cannot be handled by converter") @@ -278,6 +296,9 @@ def convert( state_dict_split_by_layer_numbers (bool, optional): Are the model layers split by layer numbers in state dict. For example : mlp.fc1.weight can be represented like mlp.fc1.weight of shape [num_layers, hidden_dim, ffn_hidden_dim]} or it can be like mlp.fc1.layers.0.weight of shape [hidden_dim, ffn_hidden_dim], then mlp.fc1.layers.1.weight ... for all layers. If you use represenation 2 set this to True. Defaults to True """ + if 'preprocess_weight' in trtllm_conversion_dict: + trtllm_conversion_dict['preprocess_weight'](model_state_dict) + # First step is to convert input model layer names to equivalent trtllm layer names model_state_dict = TRTLLMLayers.rename_input_layer_names_to_trtllm_layer_names( model_state_dict=model_state_dict, @@ -365,6 +386,7 @@ def _split(torch_tensor, tp_size, idx, dim=0): pp_layer_range = mapping.pp_layers(self.transformer_config.num_layers) + is_mamba = hasattr(trtllm_model_config, "mamba_version") trtllm_model_weights_per_gpu = {} for layer_name, value in self.trtllm_model_weights.items(): if layer_name in NON_TRANSFORMER_LAYERS_NAMES: @@ -391,6 +413,11 @@ def _split(torch_tensor, tp_size, idx, dim=0): ): layer_name = layer_name.replace("post_layernorm", "mlp_layernorm") + if is_mamba: + layer_name = layer_name.replace("transformer", "backbone") + layer_name = layer_name.replace("mlp", "layer") + layer_name = layer_name.replace("attention", "layer") + trtllm_model_weights_per_gpu[layer_name] = value if mapping.is_first_pp_rank(): @@ -404,7 +431,10 @@ def _split(torch_tensor, tp_size, idx, dim=0): else self.trtllm_model_weights[TRTLLMLayers.vocab_embedding.value] ) - trtllm_model_weights_per_gpu[TRTLLMLayers.vocab_embedding.value] = embedding_weight + key = TRTLLMLayers.vocab_embedding.value + if is_mamba: + key = key.replace("transformer", "backbone") + trtllm_model_weights_per_gpu[key] = embedding_weight pos_embedding_weight = self.trtllm_model_weights.get( TRTLLMLayers.position_embedding.value @@ -426,7 +456,10 @@ def _split(torch_tensor, tp_size, idx, dim=0): lm_head_weight, mapping.tp_size, mapping.tp_rank ) - trtllm_model_weights_per_gpu[TRTLLMLayers.final_layernorm_weight.value] = ( + key = TRTLLMLayers.final_layernorm_weight.value + if is_mamba: + key = 
key.replace("transformer", "backbone") + trtllm_model_weights_per_gpu[key] = ( self.trtllm_model_weights[TRTLLMLayers.final_layernorm_weight.value] ) From 6499b61c77ee5690a8c46c731315049f1c18d5d1 Mon Sep 17 00:00:00 2001 From: Bobby Chen Date: Tue, 26 Nov 2024 10:45:28 -0800 Subject: [PATCH 2/8] Fix key mapping --- .../mamba_hybrid_model.py | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/megatron/core/export/trtllm/model_to_trllm_mapping/mamba_hybrid_model.py b/megatron/core/export/trtllm/model_to_trllm_mapping/mamba_hybrid_model.py index e3bbd028f2..aed402eab1 100644 --- a/megatron/core/export/trtllm/model_to_trllm_mapping/mamba_hybrid_model.py +++ b/megatron/core/export/trtllm/model_to_trllm_mapping/mamba_hybrid_model.py @@ -29,29 +29,29 @@ def mamba_preprocess_weight(model_state_dict: dict): MAMBA_HYBRID_DICT = { # INPUT - 'embedding.word_embeddings.weight': TRTLLMLayers.vocab_embedding, + 'model.embedding.word_embeddings.weight': TRTLLMLayers.vocab_embedding, # ATTENTION - 'decoder.layers.self_attention.linear_qkv.weight': TRTLLMLayers.attention_qkv_weight, - 'decoder.layers.self_attention.linear_qkv.layer_norm_weight': TRTLLMLayers.input_layernorm_weight, - 'decoder.layers.self_attention.linear_proj.weight': TRTLLMLayers.attention_dense_weight, + 'model.decoder.layers.self_attention.linear_qkv.weight': TRTLLMLayers.attention_qkv_weight, + 'model.decoder.layers.self_attention.linear_qkv.layer_norm_weight': TRTLLMLayers.input_layernorm_weight, + 'model.decoder.layers.self_attention.linear_proj.weight': TRTLLMLayers.attention_dense_weight, # MLP - 'decoder.layers.mlp.linear_fc1.weight': TRTLLMLayers.mlp_fc_weight, - 'decoder.layers.mlp.linear_fc2.weight': TRTLLMLayers.mlp_projection_weight, - 'decoder.layers.mlp.linear_fc1.layer_norm_weight': TRTLLMLayers.input_layernorm_weight, + 'model.decoder.layers.mlp.linear_fc1.weight': TRTLLMLayers.mlp_fc_weight, + 'model.decoder.layers.mlp.linear_fc2.weight': TRTLLMLayers.mlp_projection_weight, + 'model.decoder.layers.mlp.linear_fc1.layer_norm_weight': TRTLLMLayers.input_layernorm_weight, # Mixer - 'decoder.layers.mixer.dt_bias': TRTLLMLayers.mixer_dt_bias, - 'decoder.layers.mixer.A_log': TRTLLMLayers.mixer_A_log, - 'decoder.layers.mixer.D': TRTLLMLayers.mixer_D, - 'decoder.layers.mixer.in_proj.layer_norm_weight': TRTLLMLayers.input_layernorm_weight, - 'decoder.layers.mixer.in_proj.weight': TRTLLMLayers.mixer_in_proj_weight, - 'decoder.layers.mixer.conv1d.weight':TRTLLMLayers.mixer_conv_weight, - 'decoder.layers.mixer.conv1d.bias': TRTLLMLayers.mixer_conv_bias, - 'decoder.layers.mixer.out_proj.weight': TRTLLMLayers.mixer_out_proj_weight, - 'decoder.layers.mixer.norm.weight': TRTLLMLayers.mixer_norm_weight, + 'model.decoder.layers.mixer.dt_bias': TRTLLMLayers.mixer_dt_bias, + 'model.decoder.layers.mixer.A_log': TRTLLMLayers.mixer_A_log, + 'model.decoder.layers.mixer.D': TRTLLMLayers.mixer_D, + 'model.decoder.layers.mixer.in_proj.layer_norm_weight': TRTLLMLayers.input_layernorm_weight, + 'model.decoder.layers.mixer.in_proj.weight': TRTLLMLayers.mixer_in_proj_weight, + 'model.decoder.layers.mixer.conv1d.weight':TRTLLMLayers.mixer_conv_weight, + 'model.decoder.layers.mixer.conv1d.bias': TRTLLMLayers.mixer_conv_bias, + 'model.decoder.layers.mixer.out_proj.weight': TRTLLMLayers.mixer_out_proj_weight, + 'model.decoder.layers.mixer.norm.weight': TRTLLMLayers.mixer_norm_weight, # FINAL LAYER NORM - 'decoder.final_norm.weight': TRTLLMLayers.final_layernorm_weight, + 'model.decoder.final_norm.weight': 
TRTLLMLayers.final_layernorm_weight, # OUTPUT LAYER - 'output_layer.weight': TRTLLMLayers.lm_head, + 'model.output_layer.weight': TRTLLMLayers.lm_head, 'preprocess_weight': mamba_preprocess_weight, } From d483276eaea136e65f81e3df87edcbba7d2b2e17 Mon Sep 17 00:00:00 2001 From: Bobby Chen Date: Tue, 26 Nov 2024 17:45:27 -0800 Subject: [PATCH 3/8] Fix for new dict --- .../mamba_hybrid_model.py | 32 +++++++------------ megatron/core/export/trtllm/trtllm_helper.py | 15 ++++++++- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/megatron/core/export/trtllm/model_to_trllm_mapping/mamba_hybrid_model.py b/megatron/core/export/trtllm/model_to_trllm_mapping/mamba_hybrid_model.py index aed402eab1..1932717bc6 100644 --- a/megatron/core/export/trtllm/model_to_trllm_mapping/mamba_hybrid_model.py +++ b/megatron/core/export/trtllm/model_to_trllm_mapping/mamba_hybrid_model.py @@ -28,30 +28,20 @@ def mamba_preprocess_weight(model_state_dict: dict): MAMBA_HYBRID_DICT = { - # INPUT - 'model.embedding.word_embeddings.weight': TRTLLMLayers.vocab_embedding, - # ATTENTION - 'model.decoder.layers.self_attention.linear_qkv.weight': TRTLLMLayers.attention_qkv_weight, - 'model.decoder.layers.self_attention.linear_qkv.layer_norm_weight': TRTLLMLayers.input_layernorm_weight, - 'model.decoder.layers.self_attention.linear_proj.weight': TRTLLMLayers.attention_dense_weight, # MLP - 'model.decoder.layers.mlp.linear_fc1.weight': TRTLLMLayers.mlp_fc_weight, - 'model.decoder.layers.mlp.linear_fc2.weight': TRTLLMLayers.mlp_projection_weight, - 'model.decoder.layers.mlp.linear_fc1.layer_norm_weight': TRTLLMLayers.input_layernorm_weight, + 'decoder.layers.mlp.linear_fc1.layer_norm_weight': TRTLLMLayers.input_layernorm_weight, # Mixer - 'model.decoder.layers.mixer.dt_bias': TRTLLMLayers.mixer_dt_bias, - 'model.decoder.layers.mixer.A_log': TRTLLMLayers.mixer_A_log, - 'model.decoder.layers.mixer.D': TRTLLMLayers.mixer_D, - 'model.decoder.layers.mixer.in_proj.layer_norm_weight': TRTLLMLayers.input_layernorm_weight, - 'model.decoder.layers.mixer.in_proj.weight': TRTLLMLayers.mixer_in_proj_weight, - 'model.decoder.layers.mixer.conv1d.weight':TRTLLMLayers.mixer_conv_weight, - 'model.decoder.layers.mixer.conv1d.bias': TRTLLMLayers.mixer_conv_bias, - 'model.decoder.layers.mixer.out_proj.weight': TRTLLMLayers.mixer_out_proj_weight, - 'model.decoder.layers.mixer.norm.weight': TRTLLMLayers.mixer_norm_weight, + 'decoder.layers.mixer.dt_bias': TRTLLMLayers.mixer_dt_bias, + 'decoder.layers.mixer.A_log': TRTLLMLayers.mixer_A_log, + 'decoder.layers.mixer.D': TRTLLMLayers.mixer_D, + 'decoder.layers.mixer.in_proj.layer_norm_weight': TRTLLMLayers.input_layernorm_weight, + 'decoder.layers.mixer.in_proj.weight': TRTLLMLayers.mixer_in_proj_weight, + 'decoder.layers.mixer.conv1d.weight':TRTLLMLayers.mixer_conv_weight, + 'decoder.layers.mixer.conv1d.bias': TRTLLMLayers.mixer_conv_bias, + 'decoder.layers.mixer.out_proj.weight': TRTLLMLayers.mixer_out_proj_weight, + 'decoder.layers.mixer.norm.weight': TRTLLMLayers.mixer_norm_weight, # FINAL LAYER NORM - 'model.decoder.final_norm.weight': TRTLLMLayers.final_layernorm_weight, - # OUTPUT LAYER - 'model.output_layer.weight': TRTLLMLayers.lm_head, + 'decoder.final_norm.weight': TRTLLMLayers.final_layernorm_weight, 'preprocess_weight': mamba_preprocess_weight, } diff --git a/megatron/core/export/trtllm/trtllm_helper.py b/megatron/core/export/trtllm/trtllm_helper.py index 8c4aa998ce..0a0584d46d 100644 --- a/megatron/core/export/trtllm/trtllm_helper.py +++ 
b/megatron/core/export/trtllm/trtllm_helper.py @@ -11,6 +11,8 @@ from megatron.core.export.trtllm.model_to_trllm_mapping.default_conversion_dict import ( DEFAULT_CONVERSION_DICT, ) +from megatron.core.export.trtllm.model_to_trllm_mapping.mamba_hybrid_model import MAMBA_HYBRID_DICT + from megatron.core.export.trtllm.trt_model_config import TRT_MODEL_CONFIG from megatron.core.export.trtllm.trt_model_type import TRT_MODEL_TYPE_STRING @@ -74,7 +76,7 @@ def __init__( 'learned_absolute', 'rope', 'none', - ], f"Position embedding type should be one of learned_absolute, rope. You entered {position_embedding_type}" + ], f"Position embedding type should be one of learned_absolute, rope, none. You entered {position_embedding_type}" self.position_embedding_type = position_embedding_type self.max_position_embeddings = max_position_embeddings self.rotary_percentage = rotary_percentage @@ -88,6 +90,17 @@ def __init__( self.hybrid_override_pattern = hybrid_override_pattern self.weights_converter = None + if model_type == ModelType.mamba_hybrid: + mamba_hybrid_dict = MAMBA_HYBRID_DICT.copy() + for k in trtllm_conversion_dict: + if k.startswith('model'): + mamba_hybrid_dict = { + k if k == 'preprocess_weight' else f"model.{k}": v + for k, v in mamba_hybrid_dict.items() + } + break + self.trtllm_conversion_dict.update(mamba_hybrid_dict) + def _get_trtllm_config( self, export_config: ExportConfig, From bffbbe2d5b5f2a82b5767e31035e8b8a20add7b1 Mon Sep 17 00:00:00 2001 From: Bobby Chen Date: Tue, 26 Nov 2024 17:57:33 -0800 Subject: [PATCH 4/8] Add norm_type in config --- megatron/core/export/trtllm/trtllm_helper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/megatron/core/export/trtllm/trtllm_helper.py b/megatron/core/export/trtllm/trtllm_helper.py index 0a0584d46d..03ccc38d33 100644 --- a/megatron/core/export/trtllm/trtllm_helper.py +++ b/megatron/core/export/trtllm/trtllm_helper.py @@ -172,6 +172,7 @@ def _get_trtllm_config( 'tp_size': export_config.inference_tp_size, 'pp_size': export_config.inference_pp_size, 'gpus_per_node': gpus_per_node, + 'norm_type': self.transformer_config.normalization.lower(), } if self.model_type == ModelType.falcon: From 5953ad05a1d49563d50b5e978a75acb3c8f6efbf Mon Sep 17 00:00:00 2001 From: Bobby Chen Date: Wed, 27 Nov 2024 11:30:41 -0800 Subject: [PATCH 5/8] Add layer_types config --- megatron/core/export/trtllm/trtllm_helper.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/megatron/core/export/trtllm/trtllm_helper.py b/megatron/core/export/trtllm/trtllm_helper.py index 03ccc38d33..ef3fd918dc 100644 --- a/megatron/core/export/trtllm/trtllm_helper.py +++ b/megatron/core/export/trtllm/trtllm_helper.py @@ -192,6 +192,15 @@ def _get_trtllm_config( * config["state_size"]) config["use_bias"] = config["bias"] config["hybrid_override_pattern"] = self.hybrid_override_pattern + layer_types = [] + for k in config['hybrid_override_pattern']: + if k == '*': + layer_types.append('attention') + elif k == 'M': + layer_types.append('recurrent') + else: + layer_types.append('-') + config['layer_types'] = layer_types config["ssm_rmsnorm"] = True if self.seq_len_interpolation_factor is not None: From 94ee31a14c07d1ce4ce3c344ebc914db3988f7ae Mon Sep 17 00:00:00 2001 From: Bobby Chen Date: Mon, 2 Dec 2024 12:21:57 -0800 Subject: [PATCH 6/8] Add residual_in_fp32 --- megatron/core/export/trtllm/trtllm_helper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/megatron/core/export/trtllm/trtllm_helper.py b/megatron/core/export/trtllm/trtllm_helper.py index ef3fd918dc..6f1f048a2c 
100644 --- a/megatron/core/export/trtllm/trtllm_helper.py +++ b/megatron/core/export/trtllm/trtllm_helper.py @@ -202,6 +202,7 @@ def _get_trtllm_config( layer_types.append('-') config['layer_types'] = layer_types config["ssm_rmsnorm"] = True + config["residual_in_fp32"] = False if self.seq_len_interpolation_factor is not None: config["rotary_scaling"] = { From a974348a884f9b64b56808665388c768daaace67 Mon Sep 17 00:00:00 2001 From: Bobby Chen Date: Fri, 10 Jan 2025 10:12:02 -0800 Subject: [PATCH 7/8] Add TP split --- ...e_device_trtllm_model_weights_converter.py | 85 +++++++++++++++---- 1 file changed, 68 insertions(+), 17 deletions(-) diff --git a/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py b/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py index 17a8129fd2..e10dd95c61 100644 --- a/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +++ b/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py @@ -63,6 +63,11 @@ def __init__( else: num_kv_heads = self.transformer_config.num_attention_heads self.num_kv_heads = num_kv_heads + # Config for Mamba hybrid + self.rnn_hidden_size = 2 * self.transformer_config.hidden_size + self.state_size = 128 + self.ngroups = 8 + self.rnn_head_size = 64 def _convert_non_transformer_layer(self, model_state_dict: dict, layer_name: str): """Convert Non Transformer layers to TRTLLM weights @@ -111,11 +116,12 @@ def _add_to_trtllm_model_weights(val: torch.Tensor, layer_name: str, split_type= self.trtllm_model_weights[f'{layer_name}.{split_num}.bin'] = ( split_val.to(self.storage_type).detach().contiguous() ) - elif split_type == 'conv': - val = val.unsqueeze(-1) - self.trtllm_model_weights[layer_name] = ( - val.to(self.storage_type).detach().contiguous() - ) + elif split_type == 'conv_weight': + for split_num, split_val in enumerate(val): + split_val = split_val.unsqueeze(-1) + self.trtllm_model_weights[f'{layer_name}.{split_num}.bin'] = ( + split_val.to(self.storage_type).detach().contiguous() + ) else: if val.ndim >= 2: val = torch.transpose(val.reshape(val.shape[0], -1), 1, 0) @@ -135,12 +141,6 @@ def _add_to_trtllm_model_weights(val: torch.Tensor, layer_name: str, split_type= or layer_name.endswith(suffix(TRTLLMLayers.attention_dense_bias)) or layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_bias)) or layer_name.endswith(suffix(TRTLLMLayers.mlp_router_weight)) - or layer_name.endswith(suffix(TRTLLMLayers.mixer_dt_bias)) - or layer_name.endswith(suffix(TRTLLMLayers.mixer_D)) - or layer_name.endswith(suffix(TRTLLMLayers.mixer_conv_bias)) - or layer_name.endswith(suffix(TRTLLMLayers.mixer_in_proj_weight)) - or layer_name.endswith(suffix(TRTLLMLayers.mixer_norm_weight)) - or layer_name.endswith(suffix(TRTLLMLayers.mixer_out_proj_weight)) ): # Same as layernorm1p in NeMo if ( @@ -152,9 +152,14 @@ def _add_to_trtllm_model_weights(val: torch.Tensor, layer_name: str, split_type= _add_to_trtllm_model_weights(val=val, layer_name=layer_name, split_type=None) - elif layer_name.endswith( - suffix(TRTLLMLayers.attention_dense_weight) - ) or layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_weight)): + elif ( + layer_name.endswith(suffix(TRTLLMLayers.attention_dense_weight)) + or layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_weight)) + or layer_name.endswith(suffix(TRTLLMLayers.mixer_norm_weight)) + or layer_name.endswith(suffix(TRTLLMLayers.mixer_dt_bias)) + 
or layer_name.endswith(suffix(TRTLLMLayers.mixer_D)) + or layer_name.endswith(suffix(TRTLLMLayers.mixer_out_proj_weight)) + ): split_vals = torch.chunk(val, self.export_config.inference_tp_size, axis=0) _add_to_trtllm_model_weights( val=split_vals, layer_name=layer_name, split_type='tensor_split' @@ -274,10 +279,56 @@ def _add_to_trtllm_model_weights(val: torch.Tensor, layer_name: str, split_type= ) elif layer_name.endswith(suffix(TRTLLMLayers.mixer_A_log)): val = -torch.exp(val.float()) - _add_to_trtllm_model_weights(val=val, layer_name=layer_name, split_type=None) - elif layer_name.endswith(suffix(TRTLLMLayers.mixer_conv_weight)): + split_vals = torch.chunk(val, self.export_config.inference_tp_size, axis=0) + _add_to_trtllm_model_weights( + val=split_vals, layer_name=layer_name, split_type='tensor_split' + ) + elif ( + layer_name.endswith(suffix(TRTLLMLayers.mixer_conv_weight)) + or layer_name.endswith(suffix(TRTLLMLayers.mixer_conv_bias)) + ): + bc_num = self.state_size * self.ngroups + + xBC = torch.split(val, [self.rnn_hidden_size, bc_num, bc_num], dim=0) + x_split = torch.chunk(xBC[0], self.export_config.inference_tp_size, axis=0) + b_split = torch.chunk(xBC[1], self.export_config.inference_tp_size, axis=0) + c_split = torch.chunk(xBC[2], self.export_config.inference_tp_size, axis=0) + + split_vals = [ + torch.concatenate(item, dim=0) + for item in zip(x_split, b_split, c_split) + ] + + if layer_name.endswith("weight"): + _add_to_trtllm_model_weights( + val=split_vals, layer_name=layer_name, split_type='conv_weight' + ) + else: + _add_to_trtllm_model_weights( + val=split_vals, layer_name=layer_name, split_type='tensor_split' + ) + elif layer_name.endswith(suffix(TRTLLMLayers.mixer_in_proj_weight)): + bc_num = self.state_size * self.ngroups + dt_num = self.rnn_hidden_size // self.rnn_head_size + + in_proj = torch.split( + val, + [self.rnn_hidden_size, self.rnn_hidden_size, bc_num, bc_num, dt_num], + dim=1 + ) + + z_split = torch.chunk(in_proj[0], self.export_config.inference_tp_size, axis=1) + x_split = torch.chunk(in_proj[1], self.export_config.inference_tp_size, axis=1) + b_split = torch.chunk(in_proj[2], self.export_config.inference_tp_size, axis=1) + c_split = torch.chunk(in_proj[3], self.export_config.inference_tp_size, axis=1) + dt_split = torch.chunk(in_proj[4], self.export_config.inference_tp_size, axis=1) + + split_vals = [ + torch.concatenate(item, dim=1) + for item in zip(z_split, x_split, b_split, c_split, dt_split) + ] _add_to_trtllm_model_weights( - val=val, layer_name=layer_name, split_type='conv' + val=split_vals, layer_name=layer_name, split_type='tensor_split' ) else: raise ValueError(f"{layer_name} cannot be handled by converter") From c0ec8f6af79b837b9322ffc42dcc00d53499c187 Mon Sep 17 00:00:00 2001 From: Bobby Chen Date: Thu, 23 Jan 2025 10:58:00 -0800 Subject: [PATCH 8/8] Add distributed changes --- megatron/core/export/trtllm/trtllm_helper.py | 2 ++ ...tributed_trtllm_model_weights_converter.py | 36 +++++++++++++++++-- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/megatron/core/export/trtllm/trtllm_helper.py b/megatron/core/export/trtllm/trtllm_helper.py index 6f1f048a2c..07176e33ca 100644 --- a/megatron/core/export/trtllm/trtllm_helper.py +++ b/megatron/core/export/trtllm/trtllm_helper.py @@ -343,6 +343,8 @@ def _get_trtllm_pretrained_config_and_model_weights_in_distributed_setting( pp_size=export_config.inference_pp_size, ) + self.weights_converter.rename_weight(trtllm_model_config) + return self.weights_converter.trtllm_model_weights, 
trtllm_model_config def _get_trtllm_pretrained_config_and_model_weights_list_on_single_device( diff --git a/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py b/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py index d50f5a3e04..4c57a23fcb 100644 --- a/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +++ b/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py @@ -53,6 +53,11 @@ def __init__( else: num_kv_heads = self.transformer_config.num_attention_heads self.num_kv_heads = num_kv_heads + # Config for Mamba hybrid + self.rnn_hidden_size = 2 * self.transformer_config.hidden_size + self.state_size = 128 + self.ngroups = 8 + self.rnn_head_size = 64 self.inference_pp_size = parallel_state.get_pipeline_model_parallel_world_size() self.inference_tp_size = parallel_state.get_tensor_model_parallel_world_size() @@ -65,11 +70,11 @@ def __init__( vp_size is None or vp_size == 1 ), "Virtual parallelism is not supported in GPU Converter. Gather the VP chunks and use PP config." - def _add_to_trtllm_model_weights(self, val: torch.Tensor, layer_name: str): + def _add_to_trtllm_model_weights(self, val: torch.Tensor, layer_name: str, transpose=True): assert torch.is_tensor(val), f"Expected a tensor for {layer_name} but got {type(val)}" val = val.to(self.storage_type) val = val.detach().contiguous() - if val.ndim >= 2: + if val.ndim >= 2 and transpose: val = torch.transpose(val.reshape(val.shape[0], -1), 0, 1) if layer_name not in self.trtllm_model_weights: self.trtllm_model_weights[layer_name] = torch.empty( @@ -100,6 +105,11 @@ def _convert_transformer_layer(self, layer_name: str, val: torch.Tensor): or layer_name.endswith(suffix(TRTLLMLayers.mlp_router_weight)) or layer_name.endswith(suffix(TRTLLMLayers.attention_dense_weight)) or layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_weight)) + or layer_name.endswith(suffix(TRTLLMLayers.mixer_norm_weight)) + or layer_name.endswith(suffix(TRTLLMLayers.mixer_dt_bias)) + or layer_name.endswith(suffix(TRTLLMLayers.mixer_D)) + or layer_name.endswith(suffix(TRTLLMLayers.mixer_out_proj_weight)) + or layer_name.endswith(suffix(TRTLLMLayers.mixer_in_proj_weight)) ): # Same as layernorm1p in NeMo if ( @@ -168,6 +178,17 @@ def _convert_transformer_layer(self, layer_name: str, val: torch.Tensor): ) self._add_to_trtllm_model_weights(val=split_vals, layer_name=layer_name) + elif layer_name.endswith(suffix(TRTLLMLayers.mixer_A_log)): + val = -torch.exp(val.float()) + self._add_to_trtllm_model_weights(val=val, layer_name=layer_name) + + elif ( + layer_name.endswith(suffix(TRTLLMLayers.mixer_conv_weight)) + or layer_name.endswith(suffix(TRTLLMLayers.mixer_conv_bias)) + ): + val = val.unsqueeze(-1) + self._add_to_trtllm_model_weights(val=val, layer_name=layer_name, transpose=False) + else: raise ValueError(f"{layer_name} cannot be handled by GPU converter") @@ -265,3 +286,14 @@ def convert( model_state_dict.items(), desc="Converting to TRTLLM Weights" ): self._convert_transformer_layer(layer_name, value) + + def rename_weight(self, trtllm_model_config: dict): + is_mamba = hasattr(trtllm_model_config, "mamba_version") + if not is_mamba: + return + for layer_name in list(self.trtllm_model_weights.keys()): + new_key = layer_name.replace("transformer", "backbone") + new_key = new_key.replace("mlp", "layer") + new_key = new_key.replace("attention", "layer") + 
# Pop first so keys left unchanged by the renames (e.g. lm_head.weight) are not dropped.
+            self.trtllm_model_weights[new_key] = self.trtllm_model_weights.pop(layer_name)
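
Taken together, the series is driven through the existing TRTLLMHelper entry
points. The sketch below shows roughly how a caller would construct the helper
for a hybrid checkpoint; it is illustrative rather than part of the patches:
transformer_config stands in for the caller's TransformerConfig, the pattern
string is a made-up seven-layer layout, and only model_type,
position_embedding_type, and hybrid_override_pattern are arguments this series
actually touches.

    from megatron.core.export.model_type import ModelType
    from megatron.core.export.trtllm.trtllm_helper import TRTLLMHelper

    trtllm_helper = TRTLLMHelper(
        transformer_config=transformer_config,  # caller-supplied TransformerConfig
        model_type=ModelType.mamba_hybrid,
        position_embedding_type='none',         # accepted after PATCH 1/8 and 3/8
        # One character per layer, decoded by the layer_types logic in PATCH 5/8:
        # 'M' -> 'recurrent' (Mamba2 mixer), '*' -> 'attention', else '-' (MLP).
        hybrid_override_pattern='M-M-M*-',      # hypothetical 7-layer pattern
    )

Everything mamba-specific then happens inside the patched converters: the key
fusion from mamba_preprocess_weight, the per-group TP splitting of the fused
in_proj/conv1d tensors, and the transformer -> backbone key renaming.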