Add Mamba TRTLLM support #1320

Open · wants to merge 10 commits into base: main
2 changes: 1 addition & 1 deletion megatron/core/export/model_type.py
@@ -3,5 +3,5 @@
from enum import Enum

ModelType = Enum(
-    'ModelType', ["gpt", "gptnext", "llama", "falcon", "starcoder", "mixtral", "gemma"]
+    'ModelType', ["gpt", "gptnext", "llama", "falcon", "starcoder", "mixtral", "gemma", "mamba_hybrid"]
)
47 changes: 47 additions & 0 deletions megatron/core/export/trtllm/model_to_trllm_mapping/mamba_hybrid_model.py
@@ -0,0 +1,47 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import torch
from megatron.core.export.trtllm.trtllm_layers import TRTLLMLayers

def mamba_preprocess_weight(model_state_dict: dict):
    """Fuse sharded Mamba mixer weights back into single tensors.

    The checkpoint stores the mixer in_proj weight as separate z/x/B/C/dt
    shards and the conv1d weight as x/B/C shards; they are concatenated
    along dim 0 under the shared key prefix before conversion.
    """
    for k in list(model_state_dict.keys()):
        if 'mixer.in_proj.weight' in k:
            if k[-1] == 'z':
                prefix = k[:-2]
                z = model_state_dict.pop(k)
                x = model_state_dict.pop(f"{prefix}.x")
                B = model_state_dict.pop(f"{prefix}.B")
                C = model_state_dict.pop(f"{prefix}.C")
                dt = model_state_dict.pop(f"{prefix}.dt")
                model_state_dict[prefix] = torch.concatenate([z, x, B, C, dt], dim=0)
        elif 'conv1d' in k:
            if k[-1] == 'x':
                prefix = k[:-2]
                x = model_state_dict.pop(k)
                B = model_state_dict.pop(f"{prefix}.B")
                C = model_state_dict.pop(f"{prefix}.C")
                model_state_dict[prefix] = torch.concatenate([x, B, C], dim=0)


MAMBA_HYBRID_DICT = {
    # MLP
    'decoder.layers.mlp.linear_fc1.layer_norm_weight': TRTLLMLayers.input_layernorm_weight,
    # Mixer
    'decoder.layers.mixer.dt_bias': TRTLLMLayers.mixer_dt_bias,
    'decoder.layers.mixer.A_log': TRTLLMLayers.mixer_A_log,
    'decoder.layers.mixer.D': TRTLLMLayers.mixer_D,
    'decoder.layers.mixer.in_proj.layer_norm_weight': TRTLLMLayers.input_layernorm_weight,
    'decoder.layers.mixer.in_proj.weight': TRTLLMLayers.mixer_in_proj_weight,
    'decoder.layers.mixer.conv1d.weight': TRTLLMLayers.mixer_conv_weight,
    'decoder.layers.mixer.conv1d.bias': TRTLLMLayers.mixer_conv_bias,
    'decoder.layers.mixer.out_proj.weight': TRTLLMLayers.mixer_out_proj_weight,
    'decoder.layers.mixer.norm.weight': TRTLLMLayers.mixer_norm_weight,
    # FINAL LAYER NORM
    'decoder.final_norm.weight': TRTLLMLayers.final_layernorm_weight,

    'preprocess_weight': mamba_preprocess_weight,
}
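
A quick sanity sketch of what `mamba_preprocess_weight` does, using toy tensors. The key names and shapes below are made up for illustration; only the `.z/.x/.B/.C/.dt` suffixes and the shared prefix matter:

```python
import torch

from megatron.core.export.trtllm.model_to_trllm_mapping.mamba_hybrid_model import (
    mamba_preprocess_weight,
)

# Hypothetical shards of one layer's mixer in_proj weight.
sd = {
    'decoder.layers.0.mixer.in_proj.weight.z': torch.zeros(4, 8),
    'decoder.layers.0.mixer.in_proj.weight.x': torch.zeros(4, 8),
    'decoder.layers.0.mixer.in_proj.weight.B': torch.zeros(2, 8),
    'decoder.layers.0.mixer.in_proj.weight.C': torch.zeros(2, 8),
    'decoder.layers.0.mixer.in_proj.weight.dt': torch.zeros(1, 8),
}
mamba_preprocess_weight(sd)

# The five shards are replaced by a single fused tensor keyed by the prefix.
assert list(sd) == ['decoder.layers.0.mixer.in_proj.weight']
assert sd['decoder.layers.0.mixer.in_proj.weight'].shape == (13, 8)
```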
1 change: 1 addition & 0 deletions megatron/core/export/trtllm/trt_model_config.py
@@ -12,4 +12,5 @@
    ModelType.llama: tensorrt_llm.models.llama.config.LLaMAConfig,
    ModelType.gemma: tensorrt_llm.models.GemmaConfig,
    ModelType.falcon: tensorrt_llm.models.falcon.config.FalconConfig,
    ModelType.mamba_hybrid: tensorrt_llm.models.mamba_hybrid.config.MambaHybridConfig,
}
1 change: 1 addition & 0 deletions megatron/core/export/trtllm/trt_model_type.py
@@ -10,4 +10,5 @@
    ModelType.llama: 'LlamaForCausalLM',
    ModelType.gemma: 'GemmaForCausalLM',
    ModelType.falcon: 'FalconForCausalLM',
    ModelType.mamba_hybrid: 'MambaHybridForCausalLM',
}
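
These two tables are what the exporter uses to resolve the TRT-LLM config class and architecture string from a `ModelType`. A minimal sketch of the lookup (this assumes a TensorRT-LLM build that actually ships the `mamba_hybrid` model referenced above):

```python
from megatron.core.export.model_type import ModelType
from megatron.core.export.trtllm.trt_model_config import TRT_MODEL_CONFIG
from megatron.core.export.trtllm.trt_model_type import TRT_MODEL_TYPE_STRING

config_cls = TRT_MODEL_CONFIG[ModelType.mamba_hybrid]          # MambaHybridConfig
architecture = TRT_MODEL_TYPE_STRING[ModelType.mamba_hybrid]   # 'MambaHybridForCausalLM'
```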
44 changes: 43 additions & 1 deletion megatron/core/export/trtllm/trtllm_helper.py
@@ -11,6 +11,8 @@
from megatron.core.export.trtllm.model_to_trllm_mapping.default_conversion_dict import (
    DEFAULT_CONVERSION_DICT,
)
from megatron.core.export.trtllm.model_to_trllm_mapping.mamba_hybrid_model import MAMBA_HYBRID_DICT

from megatron.core.export.trtllm.trt_model_config import TRT_MODEL_CONFIG
from megatron.core.export.trtllm.trt_model_type import TRT_MODEL_TYPE_STRING

@@ -42,6 +44,7 @@ def __init__(
        seq_len_interpolation_factor: float = None,
        moe_renorm_mode=None,
        share_embeddings_and_output_weights=False,
        hybrid_override_pattern: str = None,
    ):
        """Constructor for the TRTLLMHelper

@@ -72,7 +75,8 @@
        assert position_embedding_type in [
            'learned_absolute',
            'rope',
-        ], f"Position embedding type should be one of learned_absolute, rope. You entered {position_embedding_type}"
+            'none',
+        ], f"Position embedding type should be one of learned_absolute, rope, none. You entered {position_embedding_type}"
        self.position_embedding_type = position_embedding_type
        self.max_position_embeddings = max_position_embeddings
        self.rotary_percentage = rotary_percentage
@@ -83,8 +87,20 @@
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        self.moe_renorm_mode = moe_renorm_mode
        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
        self.hybrid_override_pattern = hybrid_override_pattern
        self.weights_converter = None

        if model_type == ModelType.mamba_hybrid:
            mamba_hybrid_dict = MAMBA_HYBRID_DICT.copy()
            # If the supplied conversion dict uses 'model.'-prefixed source names,
            # prefix the mamba entries the same way ('preprocess_weight' stays unprefixed).
            for k in trtllm_conversion_dict:
                if k.startswith('model'):
                    mamba_hybrid_dict = {
                        key if key == 'preprocess_weight' else f"model.{key}": v
                        for key, v in mamba_hybrid_dict.items()
                    }
                    break
            self.trtllm_conversion_dict.update(mamba_hybrid_dict)

    def _get_trtllm_config(
        self,
        export_config: ExportConfig,
@@ -156,13 +172,37 @@ def _get_trtllm_config(
            'tp_size': export_config.inference_tp_size,
            'pp_size': export_config.inference_pp_size,
            'gpus_per_node': gpus_per_node,
            'norm_type': self.transformer_config.normalization.lower(),
        }

        if self.model_type == ModelType.falcon:
            config["new_decoder_architecture"] = (
                False if self.transformer_config.num_layers == 32 else True
            )
            config["parallel_attention"] = True
        elif self.model_type == ModelType.mamba_hybrid:
            # Mamba2 mixer hyperparameters used by the hybrid exporter
            config["mamba_version"] = "Mamba2"
            config["rnn_hidden_size"] = 2 * self.transformer_config.hidden_size
            config["state_size"] = 128
            config["conv_kernel"] = 4
            config["rnn_head_size"] = 64
            config["ngroups"] = 8
            config["chunk_size"] = 128
            config["rnn_conv_dim_size"] = (
                config["rnn_hidden_size"] + 2 * config["ngroups"] * config["state_size"]
            )
            config["use_bias"] = config["bias"]
            config["hybrid_override_pattern"] = self.hybrid_override_pattern
            # '*' -> attention layer, 'M' -> Mamba (recurrent) layer, anything else -> '-'
            layer_types = []
            for symbol in config['hybrid_override_pattern']:
                if symbol == '*':
                    layer_types.append('attention')
                elif symbol == 'M':
                    layer_types.append('recurrent')
                else:
                    layer_types.append('-')
            config['layer_types'] = layer_types
            config["ssm_rmsnorm"] = True
            config["residual_in_fp32"] = False

        if self.seq_len_interpolation_factor is not None:
            config["rotary_scaling"] = {
@@ -303,6 +343,8 @@ def _get_trtllm_pretrained_config_and_model_weights_in_distributed_setting(
            pp_size=export_config.inference_pp_size,
        )

        self.weights_converter.rename_weight(trtllm_model_config)

        return self.weights_converter.trtllm_model_weights, trtllm_model_config

    def _get_trtllm_pretrained_config_and_model_weights_list_on_single_device(
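
As a side note on the `hybrid_override_pattern` handling above: each character of the pattern is mapped to a TRT-LLM layer type. A tiny illustration with a made-up 4-layer pattern (in Megatron's hybrid notation, '*' is an attention layer, 'M' a Mamba layer, and any other symbol such as '-' falls through to '-'):

```python
pattern = "M*M-"  # hypothetical 4-layer hybrid pattern

SYMBOL_TO_LAYER_TYPE = {'*': 'attention', 'M': 'recurrent'}
layer_types = [SYMBOL_TO_LAYER_TYPE.get(symbol, '-') for symbol in pattern]

print(layer_types)  # ['recurrent', 'attention', 'recurrent', '-']
```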
10 changes: 10 additions & 0 deletions megatron/core/export/trtllm/trtllm_layers.py
@@ -38,6 +38,16 @@ class TRTLLMLayers(Enum):
    mlp_projection_weight = 'transformer.layers.mlp.proj.weight'
    mlp_projection_bias = 'transformer.layers.mlp.proj.bias'

    # mamba layers
    mixer_dt_bias = 'transformer.layers.layer.dt_bias'
    mixer_A_log = 'transformer.layers.layer.A'
    mixer_D = 'transformer.layers.layer.D'
    mixer_in_proj_weight = 'transformer.layers.layer.in_proj.weight'
    mixer_out_proj_weight = 'transformer.layers.layer.out_proj.weight'
    mixer_conv_weight = 'transformer.layers.layer.conv1d.weight'
    mixer_conv_bias = 'transformer.layers.layer.conv1d.bias'
    mixer_norm_weight = 'transformer.layers.layer.norm.weight'

    # mixture of expert layers
    mlp_router_weight = 'transformer.layers.mlp.router.weight'
    mlp_fc_weight_mixture_of_experts = 'transformer.layers.mlp.fc.weight.expert'
@@ -53,6 +53,11 @@ def __init__(
        else:
            num_kv_heads = self.transformer_config.num_attention_heads
        self.num_kv_heads = num_kv_heads
        # Config for Mamba hybrid
        self.rnn_hidden_size = 2 * self.transformer_config.hidden_size
        self.state_size = 128
        self.ngroups = 8
        self.rnn_head_size = 64

        self.inference_pp_size = parallel_state.get_pipeline_model_parallel_world_size()
        self.inference_tp_size = parallel_state.get_tensor_model_parallel_world_size()
@@ -65,11 +70,11 @@
            vp_size is None or vp_size == 1
        ), "Virtual parallelism is not supported in GPU Converter. Gather the VP chunks and use PP config."

-    def _add_to_trtllm_model_weights(self, val: torch.Tensor, layer_name: str):
+    def _add_to_trtllm_model_weights(self, val: torch.Tensor, layer_name: str, transpose=True):
        assert torch.is_tensor(val), f"Expected a tensor for {layer_name} but got {type(val)}"
        val = val.to(self.storage_type)
        val = val.detach().contiguous()
-        if val.ndim >= 2:
+        if val.ndim >= 2 and transpose:
            val = torch.transpose(val.reshape(val.shape[0], -1), 0, 1)
        if layer_name not in self.trtllm_model_weights:
            self.trtllm_model_weights[layer_name] = torch.empty(
@@ -100,6 +105,11 @@ def _convert_transformer_layer(self, layer_name: str, val: torch.Tensor):
            or layer_name.endswith(suffix(TRTLLMLayers.mlp_router_weight))
            or layer_name.endswith(suffix(TRTLLMLayers.attention_dense_weight))
            or layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_weight))
            or layer_name.endswith(suffix(TRTLLMLayers.mixer_norm_weight))
            or layer_name.endswith(suffix(TRTLLMLayers.mixer_dt_bias))
            or layer_name.endswith(suffix(TRTLLMLayers.mixer_D))
            or layer_name.endswith(suffix(TRTLLMLayers.mixer_out_proj_weight))
            or layer_name.endswith(suffix(TRTLLMLayers.mixer_in_proj_weight))
        ):
            # Same as layernorm1p in NeMo
            if (
@@ -168,6 +178,17 @@ def _convert_transformer_layer(self, layer_name: str, val: torch.Tensor):
            )
            self._add_to_trtllm_model_weights(val=split_vals, layer_name=layer_name)

        elif layer_name.endswith(suffix(TRTLLMLayers.mixer_A_log)):
            # Convert the stored A_log parameter to A = -exp(A_log)
            val = -torch.exp(val.float())
            self._add_to_trtllm_model_weights(val=val, layer_name=layer_name)

        elif (
            layer_name.endswith(suffix(TRTLLMLayers.mixer_conv_weight))
            or layer_name.endswith(suffix(TRTLLMLayers.mixer_conv_bias))
        ):
            # Add a trailing unit dimension and keep the original layout (no transpose)
            val = val.unsqueeze(-1)
            self._add_to_trtllm_model_weights(val=val, layer_name=layer_name, transpose=False)

        else:
            raise ValueError(f"{layer_name} cannot be handled by GPU converter")

@@ -265,3 +286,14 @@ def convert(
            model_state_dict.items(), desc="Converting to TRTLLM Weights"
        ):
            self._convert_transformer_layer(layer_name, value)

    def rename_weight(self, trtllm_model_config):
        """Rename converted weights to the backbone.* naming used for Mamba hybrid models."""
        is_mamba = hasattr(trtllm_model_config, "mamba_version")
        if not is_mamba:
            return
        for layer_name in list(self.trtllm_model_weights.keys()):
            new_key = layer_name.replace("transformer", "backbone")
            new_key = new_key.replace("mlp", "layer")
            new_key = new_key.replace("attention", "layer")
            if new_key != layer_name:
                self.trtllm_model_weights[new_key] = self.trtllm_model_weights.pop(layer_name)
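
For illustration, here is the renaming applied to a few hypothetical converter keys. The layer indices and suffixes are made up; the intermediate `transformer.layers.<i>.layer.*` form comes from the mamba entries added to `TRTLLMLayers` above:

```python
weights = {
    'transformer.layers.0.layer.A': ...,
    'transformer.layers.1.attention.qkv.weight': ...,
    'transformer.layers.2.mlp.fc.weight': ...,
}
renamed = {
    k.replace('transformer', 'backbone').replace('mlp', 'layer').replace('attention', 'layer'): v
    for k, v in weights.items()
}
print(list(renamed))
# ['backbone.layers.0.layer.A',
#  'backbone.layers.1.layer.qkv.weight',
#  'backbone.layers.2.layer.fc.weight']
```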