diff --git a/LICENSE b/LICENSE index b49c04ee33..4782df586e 100644 --- a/LICENSE +++ b/LICENSE @@ -30,12 +30,13 @@ The following applies to all files unless otherwise noted: This repository also contains code from Hugging Face Inc., Google Research, Facebook (from their Fairseq, Dino, and ParlAI projects), Microsoft (from their -Swin-Transformer project) and Philip Popien. Files from these -organizations have notices at the top of each file. Below are -licenses used in those files, as indicated. +Swin-Transformer project), Philip Popien, and the Mamba project (Tri Dao and +Albert Gu). Files from these organizations have notices at the top of each file. +Below are licenses used in those files, as indicated. -------------- LICENSE FOR Facebook, huggingface, Google Research and LLaVA code -------------- +-------------------------------------------------------------------------------- +-- LICENSE FOR Facebook, huggingface, Google Research, LLaVA, and Mamba code -- Apache License diff --git a/README.md b/README.md index f2e4fe84b1..ba678f94f3 100644 --- a/README.md +++ b/README.md @@ -247,7 +247,6 @@ In `examples/pretrain_gpt3_175B.sh` we have provided an example of how to config With full global batch size of 1536 on 1024 A100 GPUs, each iteration takes around 32 seconds resulting in 138 teraFLOPs per GPU which is 44% of the theoretical peak FLOPs. - ## Retro and InstructRetro @@ -270,6 +269,10 @@ In this repo, we provide an end-to-end reproduction guide to implement Retro and Please see [tools/retro/README.md](tools/retro/README.md) for a detailed overview. +## Mamba-based Language Models + +Please see [examples/mamba](./examples/mamba) for details. + b pd l + xBC = rearrange(xBC, "l b d -> b d l") + xBC = xBC.contiguous() + + # Compute short convolution + if conv_state is not None: + # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + conv_state.copy_(F.pad(xBC, (self.d_conv - xBC.shape[-1], 0))) # Update state (B D W) + + seqlen = xBC.size(2) + if causal_conv1d_fn is None: + xBC = self.act(self.conv1d(xBC)[..., :seqlen]) + else: + assert self.activation in ["silu", "swish"] + xBC = causal_conv1d_fn( + x=xBC, + weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), + bias=self.conv1d.bias, + activation=self.activation, + ) + + # transpose b pd l --> l b pd + xBC = rearrange(xBC, "b d l -> l b d") + xBC = xBC.contiguous() + + x, B, C = torch.split( + xBC, + [ + self.d_inner_local, + self.ngroups_local * self.d_state, + self.ngroups_local * self.d_state, + ], + dim=-1, + ) + + # TODO Vijay: fuse most of the transposes with the GEMMS + x = rearrange(x, "l b (h p) -> b l h p", p=self.headdim).contiguous() + dt = rearrange(dt, "l b d -> b l d").contiguous() + B = rearrange(B, "l b (g n) -> b l g n", n=self.d_state).contiguous() + C = rearrange(C, "l b (g n) -> b l g n", n=self.d_state).contiguous() + z = rearrange(z, "l b (h p) -> b l h p", p=self.headdim).contiguous() + y = mamba_chunk_scan_combined( + x, + dt, + A, + B, + C, + self.chunk_size, + D=rearrange(self.D.float(), "(h p) -> h p", p=self.headdim) + if self.D_has_hdim + else self.D, + z=z if not self.rmsnorm else None, + dt_bias=self.dt_bias.float(), + dt_softplus=True, + return_final_states=ssm_state is not None, + ) + + if ssm_state is not None: + y, last_state = y + ssm_state.copy_(last_state) + + if self.rmsnorm: + y = rearrange(y, "b l h p -> b l (h p)").contiguous() + z = rearrange(z, "b l h p -> b l (h p)").contiguous() + y = self.norm(y, z) + y = rearrange(y, "b l d -> l b d").contiguous() + else: + y = rearrange(y, "b l h p -> l b (h p)").contiguous() + + # l b pd --> pl b d + out_full = y @ self.out_proj.weight.t() + if self.config.sequence_parallel: + out = reduce_scatter_to_sequence_parallel_region(out_full) + else: + out = reduce_from_tensor_model_parallel_region(out_full) + return out + + def step(self, hidden_states, conv_state, ssm_state): + # assert self.ngroups_local == 1, "Only support ngroups=1 for inference for now" + dtype = hidden_states.dtype + assert hidden_states.shape[0] == 1, "Only support decoding with 1 token at a time for now" + + # l b d --> b d + hidden_states = hidden_states.squeeze(0) + + # b d_model --> b p(2d) + xz = hidden_states @ self.in_proj.weight.t() + + z, xBC, dt = torch.split( + xz, + [ + self.d_inner_local, + self.d_inner_local + 2 * self.ngroups_local * self.d_state, + self.nheads_local, + ], + dim=-1, + ) + + # Conv step + if causal_conv1d_update is None: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = xBC + xBC = torch.sum( + conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1 + ) # (B D) + if self.conv1d.bias is not None: + xBC = xBC + self.conv1d.bias + xBC = self.act(xBC).to(dtype=dtype) + else: + xBC = causal_conv1d_update( + xBC, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation, + ) + + x, B, C = torch.split( + xBC, + [ + self.d_inner_local, + self.ngroups_local * self.d_state, + self.ngroups_local * self.d_state, + ], + dim=-1, + ) + A = -torch.exp(self.A_log.float()) + + # SSM step + if selective_state_update is None: + if self.ngroups_local > 1: + B = rearrange(B, "b (g n) -> b g n", n=self.d_state) + C = rearrange(C, "b (g n) -> b g n", n=self.d_state) + B = repeat(B, "b g n -> b (g h) n", h=self.d_inner_local // self.ngroups_local) + C = repeat(C, "b g n -> b (g h) n", h=self.d_inner_local // self.ngroups_local) + + dt = repeat(dt, "b h -> b (h p)", p=self.headdim) + dt_bias = repeat(self.dt_bias, "h -> (h p)", p=self.headdim) + A = repeat(A, "h -> (h p) n", p=self.headdim, n=self.d_state) + D = repeat(self.D, "h -> (h p)", p=self.headdim) + + dt = F.softplus(dt + dt_bias.to(dtype=dt.dtype)) + dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) + + dB_x = torch.einsum('bd,bdn,bd->bdn', dt, B, x) + ssm_state.copy_( + ssm_state * rearrange(dA, "b (h p) n -> b h p n", p=self.headdim) + + rearrange(dB_x, "b (h p) n -> b h p n", p=self.headdim) + ) + + y = torch.einsum( + "bdn,bdn->bd", + rearrange(ssm_state.to(dtype), "b h p n -> b (h p) n", p=self.headdim), + C, + ) + y = y + D.to(dtype) * x + if not self.rmsnorm: + y = y * self.act(z) # (B D) + else: + # Discretize A and B (b (g n)) + dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads) + dA = torch.exp(dt * A) + x = rearrange(x, "b (h p) -> b h p", p=self.headdim) + dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x) + ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx) + y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C) + y = y + rearrange(self.D.to(dtype), "h -> h 1") * x + y = rearrange(y, "b h p -> b (h p)") + if not self.rmsnorm: + y = y * self.act(z) # (B D) + else: + A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32) + dt = repeat(dt, "b h -> b h p", p=self.headdim) + dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim) + D = repeat(self.D, "h -> h p", p=self.headdim) + B = rearrange(B, "b (g n) -> b g n", g=self.ngroups_local) + C = rearrange(C, "b (g n) -> b g n", g=self.ngroups_local) + x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim) + if not self.rmsnorm: + z = rearrange(z, "b (h p) -> b h p", p=self.headdim) + y = selective_state_update( + ssm_state, + x_reshaped, + dt, + A, + B, + C, + D, + z=z if not self.rmsnorm else None, + dt_bias=dt_bias, + dt_softplus=True, + ) + y = rearrange(y, "b h p -> b (h p)") + + if self.rmsnorm: + y = self.norm(y, z) + + # b pd --> b d + out = y @ self.out_proj.weight.t() + out = reduce_from_tensor_model_parallel_region(out) + return out.unsqueeze(0), conv_state, ssm_state + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): + device = self.out_proj.weight.device + conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + conv_state = torch.zeros( + batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=conv_dtype + ) + ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype + # ssm_dtype = torch.float32 + ssm_state = torch.zeros( + batch_size, + self.nheads_local, + self.headdim, + self.d_state, + device=device, + dtype=ssm_dtype, + ) + return conv_state, ssm_state + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + assert self.layer_idx is not None + if self.layer_idx not in inference_params.key_value_memory_dict: + conv_state = torch.zeros( + batch_size, + self.conv1d.weight.shape[0], + self.d_conv, + device=self.conv1d.weight.device, + dtype=self.conv1d.weight.dtype, + ) + ssm_state = torch.zeros( + batch_size, + self.nheads_local, + self.headdim, + self.d_state, + device=self.in_proj.weight.device, + dtype=self.in_proj.weight.dtype, + ) + inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) + else: + conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] + # TODO: What if batch size changes between generation, and we reuse the same states? + if initialize_states: + conv_state.zero_() + ssm_state.zero_() + return conv_state, ssm_state diff --git a/megatron/core/ssm/triton_cache_manager.py b/megatron/core/ssm/triton_cache_manager.py new file mode 100644 index 0000000000..43b5b34f39 --- /dev/null +++ b/megatron/core/ssm/triton_cache_manager.py @@ -0,0 +1,44 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +import os +import socket +from pathlib import Path + +import torch + +try: + from triton.runtime.cache import FileCacheManager +except ImportError: + raise ImportError("triton is required by the Mamba model but cannot be imported") + + +def get_rank(): + return torch.distributed.get_rank() + + +def default_cache_dir(): + return os.path.join(Path.home(), ".triton", "cache") + + +class ParallelFileCacheManager(FileCacheManager): + + # See https://github.com/triton-lang/triton/blob/main/python/triton/runtime/cache.py + + # When running Triton with multiple ranks, they each create their own cache manager. Their input + # keys to that class are mostly (but not entirely) the same across ranks, which leads many ranks + # to write to the same 'key' directories in the cache dir at the same time during compilation, + # leading to conflicts. This works around that by making each cache dir be rank specific by + # adding "rank__" to the cache directory. + + def __init__(self, key): + self.key = key + self.lock_path = None + # create cache directory if it doesn't exist + self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir()) + self.cache_dir = os.path.join( + self.cache_dir, "rank_{}_{}".format(socket.gethostname(), os.getpid()) + ) + if self.cache_dir: + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) diff --git a/megatron/core/tensor_parallel/__init__.py b/megatron/core/tensor_parallel/__init__.py index 6b0aa59839..87f32a56a3 100644 --- a/megatron/core/tensor_parallel/__init__.py +++ b/megatron/core/tensor_parallel/__init__.py @@ -19,7 +19,9 @@ gather_from_sequence_parallel_region, gather_from_sequence_parallel_region_to_moe, gather_from_tensor_model_parallel_region, + reduce_from_tensor_model_parallel_region, reduce_scatter_last_dim_to_tensor_parallel_region, + reduce_scatter_to_sequence_parallel_region, reduce_scatter_to_sequence_parallel_region_from_moe, scatter_to_sequence_parallel_region, scatter_to_tensor_model_parallel_region, @@ -54,7 +56,8 @@ "copy_to_tensor_model_parallel_region", "gather_from_tensor_model_parallel_region", "gather_from_sequence_parallel_region", - # "reduce_from_tensor_model_parallel_region", + "reduce_from_tensor_model_parallel_region", + "reduce_scatter_to_sequence_parallel_region", "scatter_to_tensor_model_parallel_region", "scatter_to_sequence_parallel_region", # random.py diff --git a/megatron/core/tensor_parallel/mappings.py b/megatron/core/tensor_parallel/mappings.py index efc901fb0e..88e77541d1 100644 --- a/megatron/core/tensor_parallel/mappings.py +++ b/megatron/core/tensor_parallel/mappings.py @@ -22,7 +22,7 @@ def _reduce(input_): return input_ # All-reduce. - torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) + torch.distributed.all_reduce(input_.contiguous(), group=get_tensor_model_parallel_group()) return input_ diff --git a/megatron/inference/text_generation/tokenization.py b/megatron/inference/text_generation/tokenization.py index cab2d2ea5a..8532be9621 100644 --- a/megatron/inference/text_generation/tokenization.py +++ b/megatron/inference/text_generation/tokenization.py @@ -32,6 +32,7 @@ def detokenize_generations(tokens_gpu_tensor, for token in sequence_tokens: if args.tokenizer_type in ['SentencePieceTokenizer', 'GPTSentencePieceTokenizer', + 'HuggingFaceTokenizer', 'Llama2Tokenizer', 'MistralTokenizer']: word = tokenizer.decoder[token] diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index a0fe8e0f4c..47b6c9f7ef 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -749,7 +749,7 @@ def _add_network_size_args(parser): help='Maximum number of position embeddings to use. ' 'This is the size of position embedding.') group.add_argument('--position-embedding-type', type=str, default='learned_absolute', - choices=['learned_absolute', 'rope'], + choices=['learned_absolute', 'rope', 'none'], help='Position embedding type.') group.add_argument('--use-rotary-position-embeddings', action='store_true', help='Use rotary positional embeddings or not. ' @@ -1186,14 +1186,21 @@ def _add_learning_rate_args(parser): 'and initial warmup, the learning rate at each ' 'iteration would be different.') group.add_argument('--lr-decay-style', type=str, default='linear', - choices=['constant', 'linear', 'cosine', 'inverse-square-root'], + choices=['constant', 'linear', 'cosine', 'inverse-square-root', 'WSD'], help='Learning rate decay function.') + group.add_argument('--lr-wsd-decay-style', type=str, default='exponential', + choices=['exponential', 'linear', 'cosine'], + help='Decay style for the annealing phase of WSD'), group.add_argument('--lr-decay-iters', type=int, default=None, help='number of iterations to decay learning rate over,' ' If None defaults to `--train-iters`') group.add_argument('--lr-decay-samples', type=int, default=None, help='number of samples to decay learning rate over,' ' If None defaults to `--train-samples`') + group.add_argument('--lr-wsd-decay-samples', type=int, default=None, + help='number of samples for the annealing phase in the wsd schedule') + group.add_argument('--lr-wsd-decay-iters', type=int, default=None, + help='number of iterations for the annealing phase in the wsd schedule') group.add_argument('--lr-warmup-fraction', type=float, default=None, help='fraction of lr-warmup-(iters/samples) to use ' 'for warmup (as a float)') @@ -1488,6 +1495,7 @@ def _add_data_args(parser): 'GPT2BPETokenizer', 'SentencePieceTokenizer', 'GPTSentencePieceTokenizer', + 'HuggingFaceTokenizer', 'Llama2Tokenizer', 'Llama3Tokenizer', 'MistralTokenizer', @@ -1700,6 +1708,18 @@ def _add_experimental_args(parser): 'To use local spec specify local as the argument.' 'For more details, see the model class, ' '`transformer_block.py`, or `transformer_layer.py`') + group.add_argument('--hybrid-attention-ratio', type=float, default=0.0, + help='Ratio of attention layers to total layers, in the ' + 'range [0.0, 1.0].') + group.add_argument('--hybrid-mlp-ratio', type=float, default=0.0, + help='Ratio of mlp layers to total layers, in the ' + 'range [0.0, 1.0].') + group.add_argument('--hybrid-override-pattern', type=str, default=None, + help='Force a specific hybrid layer pattern. If a value' + 'greater than 0.0 is supplied to any of the hybrid ratio' + 'arguments, then the number of each type of layer in the' + 'override pattern must match number in the overidden' + 'pattern') group.add_argument('--yaml-cfg', type=str, default=None, help = 'Config file to add additional arguments') diff --git a/megatron/training/optimizer_param_scheduler.py b/megatron/training/optimizer_param_scheduler.py index 54a45ef098..409e1dbc7d 100644 --- a/megatron/training/optimizer_param_scheduler.py +++ b/megatron/training/optimizer_param_scheduler.py @@ -13,7 +13,9 @@ def __init__(self, optimizer, init_lr, max_lr, min_lr, lr_warmup_steps, lr_decay_steps, lr_decay_style, start_wd, end_wd, wd_incr_steps, wd_incr_style, use_checkpoint_opt_param_scheduler=True, - override_opt_param_scheduler=False): + override_opt_param_scheduler=False, + wsd_decay_steps=None, + lr_wsd_decay_style=None): # Class values. self.optimizer = optimizer @@ -28,10 +30,14 @@ def __init__(self, optimizer, init_lr, max_lr, min_lr, self.lr_warmup_steps = lr_warmup_steps self.num_steps = 0 self.lr_decay_steps = lr_decay_steps + self.wsd_decay_steps = wsd_decay_steps + self.lr_wsd_decay_style = lr_wsd_decay_style assert self.lr_decay_steps > 0 assert self.lr_warmup_steps < self.lr_decay_steps self.lr_decay_style = lr_decay_style + if self.lr_decay_style == "WSD": + assert self.wsd_decay_steps is not None self.start_wd = start_wd self.end_wd = end_wd @@ -120,6 +126,19 @@ def get_lr(self, param_group): coeff = (1.0 - decay_ratio) elif self.lr_decay_style == 'cosine': coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) + elif self.lr_decay_style == 'WSD': + wsd_anneal_start_ = self.lr_decay_steps - self.wsd_decay_steps + if self.num_steps <= wsd_anneal_start_: + coeff = 1.0 + else: + wsd_steps = self.num_steps - wsd_anneal_start_ + wsd_decay_ratio = float(wsd_steps) / float(self.wsd_decay_steps) + if self.lr_wsd_decay_style == "linear": + coeff = (1.0 - wsd_decay_ratio) + elif self.lr_wsd_decay_style == "cosine": + coeff = 0.5 * (math.cos(math.pi * wsd_decay_ratio) + 1.0) + elif self.lr_wsd_decay_style == "exponential": + coeff = ((2.0 * math.pow(0.5, wsd_decay_ratio)) - 1.0) else: raise Exception('{} decay style is not supported.'.format( self.lr_decay_style)) diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py index b5953a5c6c..b88909eea3 100644 --- a/megatron/training/tokenizer/tokenizer.py +++ b/megatron/training/tokenizer/tokenizer.py @@ -38,6 +38,8 @@ def build_tokenizer(args): elif args.tokenizer_type == 'GPTSentencePieceTokenizer': assert args.tokenizer_model is not None tokenizer = _GPTSentencePieceTokenizer(args.tokenizer_model) + elif args.tokenizer_type == 'HuggingFaceTokenizer': + tokenizer = _HuggingFaceTokenizer(args.tokenizer_model) elif args.tokenizer_type == 'Llama2Tokenizer': assert args.tokenizer_model is not None tokenizer = _Llama2Tokenizer(args.tokenizer_model) @@ -78,6 +80,48 @@ def _vocab_size_with_padding(orig_vocab_size, args): return after +class _HuggingFaceTokenizer(MegatronTokenizer): + def __init__(self, pretrained_model_name_or_path): + super().__init__(pretrained_model_name_or_path) + try: + import transformers + except ImportError: + raise EnvironmentError(f"The transformers library must be installed to use huggingface_tokenizer_provider") + + # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there + self._tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path) + self._vocab = self._tokenizer.get_vocab() + self._inv_vocab = {token_id: token for token, token_id in self._vocab.items()} + + @property + def vocab_size(self): + return len(self._tokenizer) + + @property + def vocab(self): + """Dictionary from vocab text token to id token.""" + return self._vocab + + @property + def inv_vocab(self): + """Dictionary from vocab id token to text token.""" + return self._inv_vocab + + @property + def decoder(self): + return self._inv_vocab + + def tokenize(self, text): + return self._tokenizer(text).input_ids + + def detokenize(self, token_ids): + return self._tokenizer.decode(token_ids) + + @property + def eod(self): + return self._tokenizer.eos_token_id + + class _BertWordPieceTokenizer(MegatronTokenizer): """Original BERT wordpiece tokenizer.""" diff --git a/megatron/training/training.py b/megatron/training/training.py index 8c12268d24..3b6c437be5 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -449,6 +449,9 @@ def get_optimizer_param_scheduler(optimizer): args.lr_decay_iters = args.train_iters lr_decay_steps = args.lr_decay_iters * args.global_batch_size wd_incr_steps = args.train_iters * args.global_batch_size + wsd_decay_steps = None + if args.lr_wsd_decay_iters is not None: + wsd_decay_steps = args.lr_wsd_decay_iters * args.global_batch_size if args.lr_warmup_fraction is not None: lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps else: @@ -463,6 +466,7 @@ def get_optimizer_param_scheduler(optimizer): args.lr_decay_samples = args.train_samples lr_decay_steps = args.lr_decay_samples wd_incr_steps = args.train_samples + wsd_decay_steps = args.lr_wsd_decay_samples if args.lr_warmup_fraction is not None: lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps else: @@ -484,7 +488,9 @@ def get_optimizer_param_scheduler(optimizer): wd_incr_steps=wd_incr_steps, wd_incr_style=args.weight_decay_incr_style, use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler, - override_opt_param_scheduler=args.override_opt_param_scheduler) + override_opt_param_scheduler=args.override_opt_param_scheduler, + wsd_decay_steps=wsd_decay_steps, + lr_wsd_decay_style=args.lr_wsd_decay_style) return opt_param_scheduler diff --git a/pretrain_mamba.py b/pretrain_mamba.py new file mode 100644 index 0000000000..f2dbb97e67 --- /dev/null +++ b/pretrain_mamba.py @@ -0,0 +1,239 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +"""Pretrain Mamba.""" + +import os +import torch +from functools import partial + +from megatron.training import get_args +from megatron.training import print_rank_0 +from megatron.training import get_timers +from megatron.training import get_tokenizer +from megatron.core import mpu +# from megatron.core import parallel_state +from megatron.core.enums import ModelType +from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder +from megatron.core.datasets.utils import get_blend_from_list +from megatron.core.datasets.gpt_dataset import GPTDatasetConfig +from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset +from megatron.core.models.mamba import MambaModel +from megatron.training import pretrain +from megatron.core.utils import StragglerDetector +from megatron.core.transformer.spec_utils import import_module +from megatron.training.utils import ( + get_batch_on_this_cp_rank, + get_batch_on_this_tp_rank, +) +from megatron.training.arguments import core_transformer_config_from_args +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec + + +stimer = StragglerDetector() + +def count_parameters_in_layer(model, layer_name): + num_params = 0 + for name, param in model.named_parameters(): + if layer_name in name: + num_params += param.numel() + print_rank_0(f" - {name}: {param.numel()}") + return num_params + + +def model_provider(pre_process=True, post_process=True) -> MambaModel: + """Builds the model. + + Args: + pre_process (bool, optional): Set to true if you need to compute embedings. Defaults to True. + post_process (bool, optional): Set to true if you need to want to compute output logits/loss. Defaults to True. + + + Returns: + MambaModel: The returned model + """ + args = get_args() + + print_rank_0('building Mamba model ...') + config = core_transformer_config_from_args(get_args()) + + assert args.use_legacy_models == False, "Mamba only supported in Mcore!" + + if args.spec is not None: + mamba_stack_spec = import_module(args.spec) + else: + raise("You must provide a valid Mamba layer spec!") + + model = MambaModel( + config=config, + mamba_stack_spec=mamba_stack_spec, + vocab_size=args.padded_vocab_size, + max_sequence_length=args.max_position_embeddings, + pre_process=pre_process, + hybrid_attention_ratio=args.hybrid_attention_ratio, + hybrid_mlp_ratio=args.hybrid_mlp_ratio, + hybrid_override_pattern=args.hybrid_override_pattern, + post_process=post_process, + fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, + parallel_output=True, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + position_embedding_type=args.position_embedding_type + ) + + for l in range(model.decoder.num_layers_per_pipeline_rank): + layer_params = count_parameters_in_layer(model, f'decoder.layers.{l}.') + print_rank_0(f" == params layer {l}: {layer_params}") + + return model + + +def get_batch(data_iterator): + """Generate a batch.""" + + # TODO: this is pretty hacky, find a better way + if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()): + return None, None, None, None, None + + # get batches based on the TP rank you are on + batch = get_batch_on_this_tp_rank(data_iterator) + + # slice batch along sequence dimension for context parallelism + batch = get_batch_on_this_cp_rank(batch) + + return batch.values() + +def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor): + """Loss function. + + Args: + loss_mask (torch.Tensor): Used to mask out some portions of the loss + output_tensor (torch.Tensor): The tensor with the losses + + Returns: + the loss scalar for this micro-batch + the number of non-padded tokens in this microbatch + a dict containing reporting metrics on the loss and number of tokens across + the data parallel ranks + """ + args = get_args() + + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + total_tokens = loss_mask.sum() + loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)]) + + if args.context_parallel_size > 1: + torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group()) + + # Check individual rank losses are not NaN prior to DP all-reduce. + if args.check_for_nan_in_loss_and_grad: + global_rank = torch.distributed.get_rank() + assert not loss[0].isnan(), ( + f'Rank {global_rank}: found NaN in local forward loss calculation. ' + f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}' + ) + + # Reduce loss for logging. + reporting_loss = loss.clone().detach() + torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group()) + + local_num_tokens = loss[1].clone().detach().to(torch.int) + return ( + loss[0] * args.context_parallel_size, + local_num_tokens, + {'lm loss': (reporting_loss[0], reporting_loss[1])}, + ) + + +def forward_step(data_iterator, model: MambaModel): + """Forward training step. + + Args: + data_iterator : Input data iterator + model (MambaModel): The GPT Model + """ + args = get_args() + timers = get_timers() + + # Get the batch. + timers('batch-generator', log_level=2).start() + global stimer + with stimer(bdata=True): + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + data_iterator) + timers('batch-generator').stop() + + with stimer: + output_tensor = model(tokens, position_ids, attention_mask, + labels=labels) + + return output_tensor, partial(loss_func, loss_mask) + + +def is_dataset_built_on_rank(): + return ( + mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage() + ) and mpu.get_tensor_model_parallel_rank() == 0 + + +def core_gpt_dataset_config_from_args(args): + tokenizer = get_tokenizer() + + return GPTDatasetConfig( + random_seed=args.seed, + sequence_length=args.seq_length, + blend=get_blend_from_list(args.data_path), + blend_per_split=[ + get_blend_from_list(args.train_data_path), + get_blend_from_list(args.valid_data_path), + get_blend_from_list(args.test_data_path) + ], + split=args.split, + num_dataset_builder_threads=args.num_dataset_builder_threads, + path_to_cache=args.data_cache_path, + mmap_bin_files=args.mmap_bin_files, + tokenizer=tokenizer, + reset_position_ids=args.reset_position_ids, + reset_attention_mask=args.reset_attention_mask, + eod_mask_loss=args.eod_mask_loss, + create_attention_mask=args.create_attention_mask_in_dataloader, + ) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build the train test and validation datasets. + + Args: + train_val_test_num_samples : A list containing the number of samples in train test and validation. + """ + args = get_args() + + config = core_gpt_dataset_config_from_args(args) + + if args.mock_data: + dataset_type = MockGPTDataset + else: + dataset_type = GPTDataset + + print_rank_0("> building train, validation, and test datasets for GPT ...") + + train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder( + dataset_type, + train_val_test_num_samples, + is_dataset_built_on_rank, + config + ).build() + + print_rank_0("> finished creating GPT datasets ...") + + return train_ds, valid_ds, test_ds + + +if __name__ == "__main__": + + # Temporary for transition to core datasets + train_valid_test_datasets_provider.is_distributed = True + + pretrain(train_valid_test_datasets_provider, + model_provider, + ModelType.encoder_or_decoder, + forward_step, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}) diff --git a/tools/checkpoint/hybrid_conversion.py b/tools/checkpoint/hybrid_conversion.py new file mode 100644 index 0000000000..737fac6b0f --- /dev/null +++ b/tools/checkpoint/hybrid_conversion.py @@ -0,0 +1,398 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +# Note (rwaleffe): This is a temporary file for hybrid mamba-transformer model checkpoint conversion. +# This functionality should be integrated with the megatron core checkpoint loader/saver. + + +import copy +import os +import re +import shutil +from collections import OrderedDict + +import torch +import argparse + + +tp_split_dim = { + 'word_embeddings.weight': 0, + 'norm.weight': -1, + 'final_norm.weight': -1, + 'output_layer.weight': 0, + # mamba1/2 + 'A_log': 0, + 'D': 0, + 'dt_bias': 0, + 'in_proj.weight': 0, + 'conv1d.weight': 0, + 'conv1d.bias': 0, + 'x_proj.weight': 1, + 'dt_proj.weight': 0, + 'dt_proj.bias': 0, + 'out_proj.weight': 1, + 'mixer.norm.weight': 0, + # mlp + 'linear_fc1.layer_norm_weight': -1, + 'linear_fc1.weight': 0, + 'linear_fc2.weight': 1, + # attention + 'self_attention.linear_proj.weight': 1, + 'self_attention.linear_qkv.layer_norm_weight': -1, + 'self_attention.linear_qkv.weight': 0, +} + + +def get_split_dim(tensor_name): + # norm.weight will match tensor_name of mixer.norm.weight and norm.weight, need to distinguish + if 'norm.weight' in tensor_name: + if 'mixer.norm.weight' in tensor_name: + return tp_split_dim['mixer.norm.weight'] + else: + return tp_split_dim['norm.weight'] + + for key in tp_split_dim.keys(): + if key in tensor_name: + return tp_split_dim[key] + raise Exception("Unknown tensor name {}".format(tensor_name)) + + +def combine_tp_tensors(params, key, dim, tensors): + tp_size = len(tensors) + + if 'mixer.in_proj.weight' in key and params.mamba_version == 1: + xs = []; zs = [] + for tensor in tensors: + x, z = torch.split(tensor, [params.mamba_d_inner//tp_size, + params.mamba_d_inner//tp_size], dim=dim) + xs.append(x); zs.append(z) + return torch.cat([torch.cat(xs, dim=dim), torch.cat(zs, dim=dim)], dim=dim) + + elif 'mixer.in_proj.weight' in key and params.mamba_version == 2: + xs = []; zs = []; Bs = []; Cs = []; dts = [] + for tensor in tensors: + x, z, B, C, dt = torch.split(tensor, [params.mamba_d_inner // tp_size, + params.mamba_d_inner // tp_size, + (params.mamba2_n_groups // tp_size) * args.mamba_d_state, + (params.mamba2_n_groups // tp_size) * args.mamba_d_state, + params.mamba2_n_heads // tp_size], dim=dim) + xs.append(x); zs.append(z); Bs.append(B); Cs.append(C); dts.append(dt) + + for ii in range(len(Bs)): + Bs[ii] = torch.reshape(Bs[ii], (-1, params.mamba_d_state, Bs[ii].shape[-1])) + Cs[ii] = torch.reshape(Cs[ii], (-1, params.mamba_d_state, Cs[ii].shape[-1])) + B = torch.cat(Bs, dim=dim); C = torch.cat(Cs, dim=dim) + x = torch.cat(xs, dim=dim); z = torch.cat(zs, dim=dim); dt = torch.cat(dts, dim=dim) + + return torch.cat([x, z, B.flatten(0, 1), C.flatten(0, 1), dt], dim=dim) + + elif 'mixer.conv1d' in key and params.mamba_version == 2: + xs = []; Bs = []; Cs = [] + for tensor in tensors: + x, B, C = torch.split(tensor, [params.mamba_d_inner//tp_size, + (params.mamba2_n_groups // tp_size) * params.mamba_d_state, + (params.mamba2_n_groups // tp_size) * params.mamba_d_state], dim=dim) + xs.append(x); Bs.append(B); Cs.append(C) + + for ii in range(len(Bs)): + if 'weight' in key: + Bs[ii] = torch.reshape(Bs[ii], (-1, params.mamba_d_state, Bs[ii].shape[-2], Bs[ii].shape[-1])) + Cs[ii] = torch.reshape(Cs[ii], (-1, params.mamba_d_state, Cs[ii].shape[-2], Cs[ii].shape[-1])) + elif 'bias' in key: + Bs[ii] = torch.reshape(Bs[ii], (-1, params.mamba_d_state)) + Cs[ii] = torch.reshape(Cs[ii], (-1, params.mamba_d_state)) + else: + raise Exception("Unknown key") + B = torch.cat(Bs, dim=dim); C = torch.cat(Cs, dim=dim) + x = torch.cat(xs, dim=dim) + + return torch.cat([x, B.flatten(0, 1), C.flatten(0, 1)], dim=dim) + + else: + return torch.cat(tensors, dim=dim) + + +def split_tensor_for_tp(params, key, dim, tensor): + tp_size = params.target_tp_size + tensor_sliced = [] + + if 'mixer.in_proj.weight' in key and params.mamba_version == 1: + x, z = torch.split(tensor, [params.mamba_d_inner, params.mamba_d_inner], dim=dim) + x_sliced = torch.chunk(x, tp_size, dim=dim) + z_sliced = torch.chunk(z, tp_size, dim=dim) + for (x, z) in zip(x_sliced, z_sliced): + tensor_sliced.append(torch.cat((x, z), dim=dim)) + + elif 'mixer.in_proj.weight' in key and params.mamba_version == 2: + x, z, B, C, dt = torch.split(tensor, [params.mamba_d_inner, params.mamba_d_inner, + params.mamba2_n_groups * params.mamba_d_state, + params.mamba2_n_groups * params.mamba_d_state, + params.mamba2_n_heads], dim=dim) + B = torch.reshape(B, (-1, params.mamba_d_state, B.shape[-1])) + C = torch.reshape(C, (-1, params.mamba_d_state, C.shape[-1])) + + B_sliced = torch.chunk(B, tp_size, dim=dim) + C_sliced = torch.chunk(C, tp_size, dim=dim) + x_sliced = torch.chunk(x, tp_size, dim=dim) + z_sliced = torch.chunk(z, tp_size, dim=dim) + dt_sliced = torch.chunk(dt, tp_size, dim=dim) + + tensor_sliced = [] + for (x, z, B, C, dt) in zip(x_sliced, z_sliced, B_sliced, C_sliced, dt_sliced): + tensor_sliced.append(torch.cat((x, z, B.flatten(0, 1), C.flatten(0, 1), dt), dim=dim)) + + elif 'mixer.conv1d' in key and params.mamba_version == 2: + x, B, C = torch.split(tensor, [params.mamba_d_inner, + params.mamba2_n_groups * params.mamba_d_state, + params.mamba2_n_groups * params.mamba_d_state], dim=dim) + if 'weight' in key: + B = torch.reshape(B, (-1, params.mamba_d_state, B.shape[-2], B.shape[-1])) + C = torch.reshape(C, (-1, params.mamba_d_state, C.shape[-2], C.shape[-1])) + elif 'bias' in key: + B = torch.reshape(B, (-1, params.mamba_d_state)) + C = torch.reshape(C, (-1, params.mamba_d_state)) + else: + raise Exception("Unknown key") + + B_sliced = torch.chunk(B, tp_size, dim=dim) + C_sliced = torch.chunk(C, tp_size, dim=dim) + x_sliced = torch.chunk(x, tp_size, dim=dim) + + tensor_sliced = [] + for (x, B, C) in zip(x_sliced, B_sliced, C_sliced): + tensor_sliced.append(torch.cat((x, B.flatten(0, 1), C.flatten(0, 1)), dim=dim)) + + else: + tensor_sliced = torch.chunk(tensor, tp_size, dim=dim) + + return tensor_sliced + + +def finalize_checkpoint(sample_model, model, params, verbose=False): + # make sure the rest of the checkpoint is how we want it from the original (i.e., other than the 'model') + reset_iterations = params.reset_iterations + + # checkpoint 'args' + model['args'] = copy.deepcopy(sample_model['args']) + model['args'].tensor_model_parallel_size = params.target_tp_size + model['args'].pipeline_model_parallel_size = params.target_pp_size + if reset_iterations: + model['args'].iteration = 0 + model['args'].consumed_valid_samples = 0 + model['args'].consumed_train_samples = 0 + model['args'].train_iters = 0 + model['args'].train_samples = 0 + + # checkpoint 'checkpoint_version' + model['checkpoint_version'] = copy.deepcopy(sample_model['checkpoint_version']) + + # checkpoint 'iteration' + model['iteration'] = copy.deepcopy(sample_model['iteration']) + if reset_iterations: + model['iteration'] = 0 + + # checkpoint 'optimizer' + # ignore + + # checkpoint 'opt_param_scheduler' + if 'opt_param_scheduler' in sample_model.keys(): + model['opt_param_scheduler'] = copy.deepcopy(sample_model['opt_param_scheduler']) + + # checkpoint 'rng_state' + model['rng_state'] = copy.deepcopy(sample_model['rng_state']) + + # report on argument difference + if verbose: + original_args = sample_model['args'].__dict__ + final_args = model['args'].__dict__ + for key in original_args: + if key in final_args: + if final_args[key] != original_args[key]: + print("KEY MISMATCH: {}".format(key)) + print("\toriginal: {}\n\tfinal: {}".format(original_args[key], final_args[key])) + else: + print("KEY MISSING from final: {}, value {}".format(key, original_args[key])) + print("") + for key in final_args: + if key not in original_args: + print("KEY ADDED to final: {}, value {}".format(key, final_args[key])) + + return model + + +def main(args): + print("\n====RUNNING CHECKPOINT CONVERSION====\n") + + args.mamba_d_inner = args.d_model * 2 + args.mamba2_n_heads = args.mamba_d_inner // args.mamba2_head_dim + + # get the latest iteration + tracker_filename = os.path.join(args.load_dir, 'latest_checkpointed_iteration.txt') + with open(tracker_filename, 'r') as f: + metastring = f.read().strip() + try: + iteration = int(metastring) + except ValueError: + raise Exception("") + out_iteration = iteration if not args.reset_iterations else 0 + + # get model directory and model parallel ranks + input_model_dir = os.path.join(args.load_dir, 'iter_{:07d}'.format(iteration)) + input_sub_models = os.listdir(input_model_dir) + # input_sub_models = sorted(input_sub_models, key=lambda x: int(re.search(r'\d+', x).group())) + + # load one of the model parallel ranks to get arguments + sample_model_file = os.path.join(input_model_dir, input_sub_models[0], "model_optim_rng.pt") + sample_model = torch.load(sample_model_file) + print(f"Sample model {sample_model_file} is loaded.\n") + + # input tensor and pipeline parallel size + input_tp_rank = sample_model['args'].tensor_model_parallel_size + input_pp_rank = sample_model['args'].pipeline_model_parallel_size + num_layers_per_pipeline_rank = sample_model['args'].num_layers // input_pp_rank + + # construct full model + full_model = OrderedDict() + for pp in range(input_pp_rank): + print("[INFO] Processing input pipeline rank {}".format(pp)) + tp_models = [] + for tp in range(input_tp_rank): + dir_name = "mp_rank_{:02d}".format(tp) + if input_pp_rank > 1: + dir_name += "_{:03d}".format(pp) + model_file = os.path.join(input_model_dir, dir_name, "model_optim_rng.pt") + + tp_models.append(torch.load(model_file)) + print(f"Model {model_file} is loaded.") + + if input_tp_rank > 1: + combined_tp_model = OrderedDict() + for ii, (key, original_tensor) in enumerate(tp_models[0]['model'].items()): + if "_extra_state" in key: + combined_tp_model[key] = original_tensor + continue + + split_dim = get_split_dim(key) + original_shape = list(original_tensor.shape) + combined_shape = copy.deepcopy(original_shape) + combined_shape[split_dim] *= input_tp_rank + # print("{}, {}, {}".format(ii, key, split_dim)) + + if split_dim != -1: + # slice together model + # print("\tshape mismatch: original {}, combined {}".format(original_shape, combined_shape)) + combined_tensor = combine_tp_tensors(args, key, split_dim, + [tp_models[jj]['model'][key].cpu() for jj in range(input_tp_rank)]) + combined_tp_model[key] = combined_tensor + else: + # copy model + combined_tp_model[key] = original_tensor + else: + combined_tp_model = tp_models[0]['model'] + # print("Combined tp model: {}".format(combined_tp_model.keys())) + + for ii, (key, original_tensor) in enumerate(combined_tp_model.items()): + try: + layer_num = int(re.findall(r'\d+', key)[0]) + new_key = key.replace(str(layer_num), str(layer_num + pp*num_layers_per_pipeline_rank), 1) + except: + new_key = key + full_model[new_key] = original_tensor + # print("Combined model: {}".format(full_model.keys())) + print("\n[INFO] Loaded combined model\n") + + # sort by layer + # full_model_sorted = dict(sorted(people.items(), key=lambda item: item[1])) + + # create new split model + pp_offset = 0 + num_layers_per_pipeline_rank = sample_model['args'].num_layers // args.target_pp_size + + for pp in range(args.target_pp_size): + print("[INFO] Processing output pipeline rank {}".format(pp)) + tp_models = [] + for ii in range(args.target_tp_size): + tp_models.append({'model': OrderedDict()}) + + for ii, (key, original_tensor) in enumerate(full_model.items()): + try: + layer_num = int(re.findall(r'\d+', key)[0]) + if layer_num >= num_layers_per_pipeline_rank * (pp+1): + break + new_key = key.replace(str(layer_num), str(layer_num - (pp * num_layers_per_pipeline_rank)), 1) + except: + new_key = key + + if ii < pp_offset: + continue + else: + pp_offset += 1 + + if "_extra_state" in new_key: + # copy + for jj in range(args.target_tp_size): + tp_models[jj]['model'][new_key] = original_tensor + continue + + split_dim = get_split_dim(new_key) + original_shape = list(original_tensor.shape) + v0 = original_shape[split_dim] + split_size = v0 // args.target_tp_size + split_shape = copy.deepcopy(original_shape) + split_shape[split_dim] = split_size + # print("{}, {}, {}".format(ii, new_key, split_dim)) + + if split_dim != -1: + # split model + # print("\tshape mismatch: original {}, combined {}".format(original_shape, split_shape)) + tensor_sliced = split_tensor_for_tp(args, new_key, split_dim, original_tensor) + for jj in range(args.target_tp_size): + tp_models[jj]['model'][new_key] = tensor_sliced[jj] + else: + # copy model + for jj in range(args.target_tp_size): + tp_models[jj]['model'][new_key] = original_tensor + # print(tp_models[0]['model'].keys()) + + for tp in range(args.target_tp_size): + dir_name = "mp_rank_{:02d}".format(tp) + if args.target_pp_size > 1: + dir_name += "_{:03d}".format(pp) + + model = finalize_checkpoint(sample_model, tp_models[tp], args, verbose=False) + + save_dir = os.path.join(args.save_dir, 'iter_{:07d}'.format(out_iteration), dir_name) + os.makedirs(save_dir, exist_ok=True) + model_file = os.path.join(save_dir, "model_optim_rng.pt") + torch.save(model, model_file) + print(f"Model {model_file} is saved.") + + # shutil.copyfile(tracker_filename, os.path.join(args.save_dir, 'latest_checkpointed_iteration.txt')) + tracker_filename = os.path.join(args.save_dir, 'latest_checkpointed_iteration.txt') + with open(tracker_filename, 'w') as f: + f.write(str(out_iteration)) + + +if __name__ == "__main__": + # example run command: + # python hybrid_conversion.py + # --load-dir mamba2-840m-test/checkpoints/ + # --save-dir mamba2-840m-test-conversion/checkpoints/ + # --target-pp-size 1 + # --target-tp-size 1 + + parser = argparse.ArgumentParser() + parser.add_argument('--load-dir', type=str) + parser.add_argument('--save-dir', type=str) + parser.add_argument('--target-tp-size', type=int, default=1) + parser.add_argument('--target-pp-size', type=int, default=1) + parser.add_argument('--reset-iterations', action='store_true') + + parser.add_argument('--d-model', type=int, default=4096) + parser.add_argument('--mamba-version', type=int, default=2) + parser.add_argument('--mamba-d-state', type=int, default=128) + parser.add_argument('--mamba2-n-groups', type=int, default=8) + parser.add_argument('--mamba2-head-dim', type=int, default=64) + + args = parser.parse_args() + + main(args) \ No newline at end of file diff --git a/tools/run_mamba_text_generation_server.py b/tools/run_mamba_text_generation_server.py new file mode 100644 index 0000000000..844d018055 --- /dev/null +++ b/tools/run_mamba_text_generation_server.py @@ -0,0 +1,121 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +"""Sample Generate Mamba""" +import os +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), + os.path.pardir))) +from megatron.training import get_args +from megatron.training import print_rank_0 +from megatron.core import mpu +from megatron.training.checkpointing import load_checkpoint +from megatron.training.initialize import initialize_megatron +from megatron.core.models.mamba.mamba_model import MambaModel +from megatron.core.transformer.spec_utils import import_module +from megatron.training import get_model +from megatron.training.arguments import core_transformer_config_from_args +from megatron.inference.text_generation_server import MegatronServer +from megatron.inference.text_generation import generate_and_post_process +from megatron.inference.text_generation import beam_search_and_post_process + +import torch + +def count_parameters_in_layer(model, layer_name): + num_params = 0 + for name, param in model.named_parameters(): + if layer_name in name: + num_params += param.numel() + print_rank_0(f" - {name}: {param.numel()}") + return num_params + +# Taken from pretrain_mamba.py +def model_provider(pre_process=True, post_process=True) -> MambaModel: + """Builds the model. + + Args: + pre_process (bool, optional): Set to true if you need to compute embedings. Defaults to True. + post_process (bool, optional): Set to true if you need to want to compute output logits/loss. Defaults to True. + + + Returns: + MambaModel: The returned model + """ + args = get_args() + + print_rank_0('building Mamba model ...') + config = core_transformer_config_from_args(get_args()) + + assert args.use_legacy_models == False, "Mamba only supported in Mcore!" + + if args.spec is not None: + mamba_stack_spec = import_module(args.spec) + else: + raise("You must provide a valid Mamba layer spec!") + + model = MambaModel( + config=config, + mamba_stack_spec=mamba_stack_spec, + vocab_size=args.padded_vocab_size, + max_sequence_length=args.max_position_embeddings, + pre_process=pre_process, + hybrid_attention_ratio=args.hybrid_attention_ratio, + hybrid_mlp_ratio=args.hybrid_mlp_ratio, + hybrid_override_pattern=args.hybrid_override_pattern, + post_process=post_process, + fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, + parallel_output=True, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + position_embedding_type=args.position_embedding_type + ) + + for l in range(model.decoder.num_layers_per_pipeline_rank): + layer_params = count_parameters_in_layer(model, f'decoder.layers.{l}.') + print_rank_0(f" == params layer {l}: {layer_params}") + + return model + +def add_text_generate_args(parser): + group = parser.add_argument_group(title='text generation') + group.add_argument("--port", type=int, default=5000, + help='port for text generation server to run on') + return parser + + +if __name__ == "__main__": + initialize_megatron(extra_args_provider=add_text_generate_args, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer', + 'no_load_rng': True, + 'no_load_optim': True}) + + args = get_args() + if args.num_layers_per_virtual_pipeline_stage is not None: + print("Interleaved pipeline schedule is not yet supported for text generation.") + exit() + print_rank_0("WARNING: Forcing exit_on_missing_checkpoint to True for text " + "generation.") + args.exit_on_missing_checkpoint = True + # Set up model and load checkpoint + model = get_model(model_provider, wrap_with_ddp=False) + + if args.load is not None: + _ = load_checkpoint(model, None, None) + + assert len(model) == 1, "Above condition should have caught this" + model = model[0] + if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0: + server = MegatronServer(model) + server.run("0.0.0.0",port=args.port) + + while True: + choice = torch.tensor(1, dtype=torch.long, device='cuda') + torch.distributed.broadcast(choice, 0) + if choice.item() == 0: + try: + generate_and_post_process(model) + except ValueError as ve: + pass + elif choice.item() == 1: + try: + beam_search_and_post_process(model) + except ValueError as ve: + pass