From b1e36c46e03b192abd1633c31a90549387f22ab6 Mon Sep 17 00:00:00 2001 From: Tuomas Rintamaki Date: Wed, 14 Aug 2024 17:58:13 -0700 Subject: [PATCH] ADLR/megatron-lm!1402 - LLaVA expanded data processing --- .gitlab/stages/01.tests.yml | 14 +- examples/multimodal/run_text_generation.py | 10 +- examples/multimodal/train.py | 31 +-- .../core/models/multimodal/llava_model.py | 237 ++++++++++++++++-- megatron/core/models/vision/clip_vit_model.py | 8 + pretrain_vlm.py | 85 ++++--- .../golden_values.json | 2 +- .../golden_values.json | 2 +- .../golden_values.json | 2 +- tests/unit_tests/models/test_llava_model.py | 175 +++++++++++-- 10 files changed, 455 insertions(+), 111 deletions(-) diff --git a/.gitlab/stages/01.tests.yml b/.gitlab/stages/01.tests.yml index ae26823266..ea9076ce35 100644 --- a/.gitlab/stages/01.tests.yml +++ b/.gitlab/stages/01.tests.yml @@ -76,7 +76,7 @@ unit_tests: parallel: matrix: - TAG: latest - - TAG: db5c60ae3fe5247f16ec0536bbf41ee5c7fb9c4a + - TAG: a5efe829b1d34c691f0a7a5286e271b4f9c86b2a tags: [8xL40S] variables: GIT_STRATEGY: clone @@ -89,11 +89,15 @@ unit_tests: cp -r tests/ /opt/megatron-lm fi script: - - | - cd /opt/megatron-lm + - | + cd /opt/megatron-lm for i in $(seq $UNIT_TEST_REPEAT); do SEED=$((RANDOM % 9000 + 1000)); - timeout ${UNIT_TEST_TIMEOUT}m torchrun --nproc_per_node=8 -m pytest --random-order --random-order-seed ${SEED} -xvs --cov-report=term --cov-report=html --cov=megatron/core --no-cov-on-fail `$([[ $TAG != latest ]] && echo -m 'not internal')` tests/unit_tests + SKIPPED=() + if [[ $TAG != latest ]]; then + SKIPPED+=(-m "not internal") + fi + timeout ${UNIT_TEST_TIMEOUT}m torchrun --nproc_per_node=8 -m pytest --random-order --random-order-seed ${SEED} -xvs --cov-report=term --cov-report=html --cov=megatron/core --no-cov-on-fail "${SKIPPED[@]}" tests/unit_tests done artifacts: paths: @@ -143,7 +147,7 @@ secret_detection: - apk add jq - /analyzer run - | - if [[ $(cat gl-secret-detection-report.json | jq '.vulnerabilities | length > 0') == true ]]; then + if [[ $(cat gl-secret-detection-report.json | jq '.vulnerabilities | length > 0') == true ]]; then echo "Atleast one vulnerability has been found" cat gl-secret-detection-report.json | jq '.' exit 1 diff --git a/examples/multimodal/run_text_generation.py b/examples/multimodal/run_text_generation.py index 24a2e19186..961fc6c653 100644 --- a/examples/multimodal/run_text_generation.py +++ b/examples/multimodal/run_text_generation.py @@ -19,6 +19,7 @@ from torchvision.transforms import Compose, Resize, ToPILImage from train import add_multimodal_extra_args, get_image_token_count, model_provider +from megatron.core.models.multimodal.llava_model import IMAGE_TOKEN_INDEX from megatron.inference.text_generation.api import generate_and_post_process from megatron.inference.text_generation.forward_step import ForwardStep from megatron.training import get_args, get_model, print_rank_0 @@ -282,7 +283,7 @@ def generate_samples(model): elif args.task in ("TextVQA", "MMMU"): output_name = "text" - generated = generation[len(prompt) + 1 :] + generated = generation[len(prompt):] output[output_name] = generated if args.task == "captioning": @@ -329,6 +330,13 @@ def __init__(self, images, num_image_tokens, model, max_batch_size, max_sequence self._images = images def _forward(self, tokens, position_ids, attention_mask): + # Add image token index to the front if it's not included in the prompt. Note: This will change in a future MR. 
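+ # IMAGE_TOKEN_INDEX (-200) is a placeholder that LLaVAModel._preprocess_data replaces with the image embeddings; the check below only prepends it to the full prompt, not to single-token decode steps.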
+ num_tokens = tokens.shape[1] + + if num_tokens > 1 and torch.sum(tokens == IMAGE_TOKEN_INDEX).item() == 0: + tokens = torch.cat([torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=tokens.dtype, device=tokens.device), tokens], dim=1) + position_ids = torch.arange(num_tokens, dtype=position_ids.dtype, device=position_ids.device) + return self.model( self._images, tokens, diff --git a/examples/multimodal/train.py b/examples/multimodal/train.py index a1eb8b2b26..56f2b0d741 100644 --- a/examples/multimodal/train.py +++ b/examples/multimodal/train.py @@ -85,7 +85,7 @@ def model_provider( vision_config.pipeline_model_parallel_size = args.encoder_pipeline_model_parallel_size vision_projection_config.pipeline_model_parallel_size = args.encoder_pipeline_model_parallel_size if args.encoder_tensor_model_parallel_size > 0: - vision_transformer_config.tensor_model_parallel_size = args.encoder_tensor_model_parallel_size + vision_config.tensor_model_parallel_size = args.encoder_tensor_model_parallel_size vision_projection_config.tensor_model_parallel_size = args.encoder_tensor_model_parallel_size vision_projection_layer_spec = get_mlp_module_spec(use_te=use_te).submodules @@ -113,7 +113,6 @@ def model_provider( img_w=args.img_w, patch_dim=args.patch_dim, language_rotary_base=args.rotary_base, - img_embedding_idx=args.img_embedding_idx, ) model.freeze(freeze_language_model=args.freeze_LM, freeze_vision_model=args.freeze_ViT, freeze_vision_projection=False) @@ -171,10 +170,6 @@ def get_batch(data_iterator): question_length=prompt_len) torch.cuda.nvtx.range_pop() - loss_mask, labels, attention_mask = _preprocess_data_for_llava(loss_mask, labels, attention_mask) - - tokens = tokens[:, 1:] # drop image index token - return tokens, labels, loss_mask, attention_mask, position_ids, img_raw @@ -191,24 +186,6 @@ def get_image_token_count(): return num_image_tokens -def _preprocess_data_for_llava(loss_mask, labels, attention_mask): - """Preprocess data sample to the format expected by a LLaVA model.""" - num_image_tokens = get_image_token_count() - - batch_size = loss_mask.shape[0] - - loss_mask2 = torch.cat( - [torch.zeros(batch_size, num_image_tokens - 1, dtype=torch.float32, device=loss_mask.device), loss_mask], dim=1 - ) - labels2 = torch.cat([torch.zeros(batch_size, num_image_tokens - 1, dtype=torch.int64, device=labels.device), labels], dim=1) - - full_seq_length = len(labels2[0]) - attention_mask2 = torch.tril(torch.ones((1, 1, full_seq_length, full_seq_length), device=attention_mask.device)) - attention_mask2 = attention_mask2 < 0.5 - - return loss_mask2, labels2, attention_mask2 - - def get_ltor_masks_and_position_ids(data, eod_token, reset_position_ids, @@ -312,7 +289,7 @@ def forward_step(data_iterator, model: LLaVAModel): tokens, labels, loss_mask, attention_mask, position_ids, images = get_batch(data_iterator) timers('batch-generator').stop() - output_tensor = model(images, tokens, position_ids, attention_mask, labels=labels) + output_tensor, loss_mask = model(images, tokens, position_ids, attention_mask, labels, loss_mask) return output_tensor, partial(loss_func, loss_mask) @@ -332,10 +309,6 @@ def add_multimodal_extra_args(parser): group.add_argument("--disable-vision-class-token", action="store_true", default=False) group.add_argument("--allow-missing-vision-projection-checkpoint", action="store_true", default=False) group.add_argument("--use-te", action="store_true", default=False) - group.add_argument("--img-embedding-idx", type=int, default=0, - help='Llava specific parameter. 
Defines at which index' - 'in the language_embedding tensor the image_embeddings' - 'should be inserted') group.add_argument("--dataloader-save", type=str, default=None, help="Energon dataloader state save path") return parser diff --git a/megatron/core/models/multimodal/llava_model.py b/megatron/core/models/multimodal/llava_model.py index 6acc92630c..f15418e4b6 100644 --- a/megatron/core/models/multimodal/llava_model.py +++ b/megatron/core/models/multimodal/llava_model.py @@ -6,15 +6,17 @@ import torch -from megatron.core import InferenceParams, parallel_state +from megatron.core import InferenceParams from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk from megatron.core.models.gpt import GPTModel -from megatron.core.models.vision.clip_vit_model import CLIPViTModel +from megatron.core.models.vision.clip_vit_model import CLIPViTModel, get_image_sequence_length from megatron.core.models.vision.multimodal_projector import MultimodalProjector from megatron.core.transformer import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import make_viewless_tensor + +IMAGE_TOKEN_INDEX = -200 # ID for images in the input sequence. +IGNORE_INDEX = -100 # ID for labels that should be ignored. # Note: This is under development and may be missing features. @@ -45,7 +47,6 @@ class LLaVAModel(MegatronModule): img_h (int): The height of each image that the ViT will see. img_w (int): The width of each image that the ViT will see. patch_dim (int): The size of each patch side. - img_embedding_idx (int): Index in the language_embeddings tensor where image_embeddings should be inserted. Defaults to 0. """ def __init__( @@ -72,7 +73,6 @@ def __init__( img_w: int = 336, patch_dim: int = 14, language_rotary_base: int = 10000, - img_embedding_idx: int = 0, ) -> None: super().__init__(config=language_transformer_config) @@ -87,7 +87,6 @@ def __init__( self.post_process = post_process self.add_encoder = add_encoder self.add_decoder = add_decoder - self.img_embedding_idx = img_embedding_idx self.encoder_hidden_state = None self.vision_model = None @@ -114,12 +113,14 @@ def __init__( self.language_model.share_embeddings_and_output_weights ) + class_token_len = 1 if self.add_encoder: self.vision_model = CLIPViTModel( vision_transformer_config, vision_transformer_layer_spec, img_h=img_h, img_w=img_w, + class_token_len=class_token_len, patch_dim=patch_dim, ) self._drop_vision_class_token = drop_vision_class_token @@ -142,6 +143,10 @@ def __init__( partial(_load_state_dict_hook_ignore_param_names, vision_projection_param_names) ) + self._img_seq_len = get_image_sequence_length( + img_h, img_w, patch_dim, not drop_vision_class_token, class_token_len + ) + def shared_embedding_or_output_weight(self): """This is a convenience method to surface the language model's word embeddings, which is necessary for `finalize_model_grads._allreduce_word_embedding_grads`.""" @@ -190,6 +195,172 @@ def freeze( for param in module.parameters(): param.requires_grad = False + def _preprocess_data( + self, + image_embeddings, + language_embeddings, + input_ids, + loss_mask, + labels, + use_inference_kv_cache, + image_token_index, + ): + """Preprocess input data before input to language model. 
+ + This function is adopted from + https://github.com/huggingface/transformers/blob/85817d98fb60977c97e3014196a462b732d2ed1a/src/transformers/models/llava_next/modeling_llava_next.py#L409 + for our input data conventions. + + image_token_index = -200 indicates the image position in the input_ids = [0, 1, -200, 2, 3] and labels = [1, -200, 2, 3, 4], for example. + We want to replace the image position (-200) with image_embeddings and return the following: + - final_embeddings = [0, 1, image_embeddings, 2, 3], + - final_labels = [1, -100, 2, 3, 4] + - final_loss_mask = [1, 0, 0, 1, 1] + + This function also handles the case where the input does not contain an image (text-only sample). + + If pipeline parallelism is not used, then self.pre_process and self.post_process are both True and we update both + input embeddings, labels and loss masks (if available). + + If pipeline parallelism is used, then we do the following + - the first language model chunk has self.pre_process = True and self.post_process = False. We update input embeddings. + - the middle language model chunk(s) has self.pre_process = False and self.post_process = False. We don't need to update anything. + - the last language model chunk has self.pre_process = False and self.post_process = True. We update labels and loss mask. + + TODO: This function should adjust the attention mask too. Currently, we assume the language model uses a causal mask. + + Returns: + final_embedding (torch.Tensor): image and text embeddings concated [combined_seq_len, b, h]. + final_labels (torch.Tensor): labels for image and text positions [b, combined_seq_len]. + final_loss_mask (torch.Tensor): loss mask for image and text positions [b, combined_seq_len]. + """ + assert self.add_decoder, "input text preprocessing is only needed for the language model" + + # No pre- or postprocessing needed. With pipeline parallel > 2, this means a chunk in the middle of the model. + if not self.pre_process and not self.post_process: + return language_embeddings, loss_mask, labels + + # If using the inference KV cache, the image tokens are already computed. + if use_inference_kv_cache: + return language_embeddings, loss_mask, labels + + img_seq_len = ( + self._img_seq_len - 1 + ) # Adjust by -1 to account for the removed image token index. + batch_size, text_seq_len = input_ids.shape + + has_labels = labels is not None + if has_labels: + assert ( + labels.shape == loss_mask.shape + ), f"mismatching labels shape {labels.shape} and loss mask shape {loss_mask.shape}" + + with torch.no_grad(): + image_token_mask = input_ids == image_token_index + num_image_tokens = torch.sum(image_token_mask, dim=-1) + + max_seq_len = (num_image_tokens.max() * img_seq_len) + text_seq_len + batch_indices, non_image_indices = torch.where(input_ids != image_token_index) + + # New position ids for the text tokens, shifted by the image sequence length. + # E.g. for input_ids = [-200, 1, 2, 3] and img_seq_len = 576, we get new_position_ids = [576, 577, 578, 579]. + # text_position_ids are then [577, 578, 579]. + # +1 is needed here for the cumulative sum. -1 is adjusting for zero-based indexing. + new_position_ids = torch.cumsum((image_token_mask * img_seq_len + 1), dim=-1) - 1 + text_position_ids = new_position_ids[batch_indices, non_image_indices] + + # Repeat the same for labels, which have the image token index shifted to left by one. + # An exception is an input sequence starting with an image token in which case + # the image token is not present in labels so we correct for it. 
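+ # E.g. for input_ids = [-200, 1, 2, 3], the labels are [1, 2, 3, 4] and contain no image token, yet their position ids must still be shifted by img_seq_len.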
+ if has_labels: + edge = input_ids[:, 0] == image_token_index + label_image_token_mask = labels == image_token_index + label_batch_indices, label_non_image_indices = torch.where( + labels != image_token_index + ) + + new_label_position_ids = ( + torch.cumsum((label_image_token_mask * img_seq_len + 1), dim=-1) - 1 + ) + # If the input sequence starts with an image token, then that image token is not present in the labels + # and we need to shift the label position ids by the image sequence length. + new_label_position_ids[edge] += img_seq_len + label_text_position_ids = new_label_position_ids[ + label_batch_indices, label_non_image_indices + ] + + # Initialize output tensors. + final_embedding = None + if self.pre_process: + embed_dim = language_embeddings.shape[-1] + final_embedding = torch.zeros( + batch_size, + max_seq_len, + embed_dim, + dtype=image_embeddings.dtype, + device=image_embeddings.device, + ) + + final_labels, final_loss_mask = None, None + if has_labels: + final_labels = torch.full( + (batch_size, max_seq_len), IGNORE_INDEX, dtype=labels.dtype, device=labels.device + ) + final_loss_mask = torch.full( + (batch_size, max_seq_len), 0, dtype=loss_mask.dtype, device=loss_mask.device + ) + + # Put text embeddings to the text positions in the result tensor. + if self.pre_process: + final_embedding[batch_indices, text_position_ids] = language_embeddings[ + batch_indices, non_image_indices + ] + + # Put text labels and loss mask to the text positions. + if has_labels: + final_labels[label_batch_indices, label_text_position_ids] = labels[ + label_batch_indices, label_non_image_indices + ] + final_loss_mask[batch_indices, text_position_ids] = loss_mask[ + batch_indices, non_image_indices + ] + + with torch.no_grad(): + # Create a mask for the image embedding positions. + images_mask = torch.full( + (batch_size, max_seq_len), True, dtype=torch.bool, device=input_ids.device + ) + images_mask[batch_indices, text_position_ids] = ( + False # No images in the text positions. + ) + # Samples can have different amount of images tokens. new_position_ids[:, -1] gives the last text position id for each sample. + # Padding is needed when the number of image tokens differs. Compute the number of padding tokens on the right for each sample. + padding = max_seq_len - 1 - new_position_ids[:, -1] + # Mark the padding tokens on the right as False in the images mask. -1 adjusts cumulative sum to be zero-based. + images_mask &= images_mask.cumsum(dim=-1) - 1 >= padding[:, None] + + if self.pre_process: + final_embedding[images_mask] = image_embeddings.reshape(-1, embed_dim).contiguous() + + if has_labels: + # Loss mask the image positions. + final_loss_mask[images_mask] = 0 + + # Loss mask last text position just before an image so that text token does not need to predict the first image token. 
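+ # E.g. for input_ids = [0, 1, -200, 2, 3], the text token at position 1 immediately precedes the image, so its loss is zeroed below.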
+ batch_image_indices, image_indices = torch.where(image_token_mask) + text_before_image_indices = torch.maximum(image_indices - 1, torch.tensor(0)) + final_loss_mask[batch_image_indices, text_before_image_indices] = 0 + + if final_embedding is not None and has_labels: + assert ( + final_embedding.shape[:2] == final_labels.shape == final_loss_mask.shape + ), "unexpected shapes after data preprocessing" + + if final_embedding is not None: + final_embedding = final_embedding.transpose(1, 0).contiguous() + + return final_embedding, final_labels, final_loss_mask + def forward( self, images: torch.Tensor, @@ -197,7 +368,9 @@ def forward( position_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor = None, + loss_mask: torch.Tensor = None, inference_params: InferenceParams = None, + image_token_index: int = IMAGE_TOKEN_INDEX, ) -> torch.Tensor: """Forward function of the LLaVA model. @@ -205,11 +378,15 @@ def forward( images (torch.Tensor): input image of shape [batch, img_h, img_w]. input_ids (torch.Tensor): input text ids [batch, text_seq_len]. position_ids (torch.Tensor): input text position ids [batch, text_seq_len]. - attention_mask (torch.Tensor): attention mask for the language model [batch, 1, combined_seq_len, combined_seq_len]. + attention_mask (torch.Tensor): Attention mask for the language model [batch, 1, combined_seq_len, combined_seq_len]. labels (torch.Tensor): Optional target text labels [batch, combined_seq_len]. + loss_mask (torch.Tensor): Text loss mask [batch, text_seq_len]. inference_params (InferenceParams): Inference-time parameters including KV cache. + image_token_index (int): ID for input images. + Returns: output (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape [b, s, vocab_size]. + loss_mask (torch.Tensor): Loss mask expanded to combined sequence length. Shape [b, s]. """ use_inference_kv_cache = ( inference_params is not None @@ -226,6 +403,7 @@ def forward( image_embeddings = image_embeddings.permute( 1, 0, 2 ).contiguous() # [img_seq_len, b, h_vision] + # map vision model output size to language model input size. image_embeddings = self.vision_projection( image_embeddings @@ -241,38 +419,45 @@ def forward( image_embeddings = self.encoder_hidden_state if not self.add_decoder: - return image_embeddings + return image_embeddings, loss_mask + language_embeddings = None if self.pre_process: + input_ids_text = input_ids.clone() + input_ids_text[input_ids_text == image_token_index] = 0 + # Note: This adds absolute position embedding but not RoPE. Each image is counted as one position. + # RoPE is added in language_model forward call. Each image embedding is one position. language_embeddings = self.language_model.embedding( - input_ids=input_ids, position_ids=position_ids + input_ids=input_ids_text, position_ids=position_ids ) # [text_seq_len, b, h_language] - - # If running inference, we can skip image token computation if they were computed already earlier for this sample. - if use_inference_kv_cache: - combined_embeddings = language_embeddings - else: - combined_embeddings = torch.cat( - [ - language_embeddings[: self.img_embedding_idx], - image_embeddings, - language_embeddings[self.img_embedding_idx :], - ], - dim=0, - ) # [combined_seq_len, b, h_language] - else: - combined_embeddings = None + language_embeddings = language_embeddings.transpose( + 1, 0 + ).contiguous() # [b, text_seq_len, h_language] + + # Preprocess input, labels and loss mask. 
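+ # Image token placeholders in input_ids are replaced by image embeddings, and labels/loss mask are expanded to the combined sequence length.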
+ combined_embeddings, new_labels, new_loss_mask = self._preprocess_data( + image_embeddings, + language_embeddings, + input_ids, + loss_mask, + labels, + use_inference_kv_cache, + image_token_index, + ) # [combined_seq_len, b, h_language], [b, combined_seq_len], [b, combined_seq_len] output = self.language_model( input_ids=None, position_ids=None, attention_mask=attention_mask, decoder_input=combined_embeddings, - labels=labels, + labels=new_labels, inference_params=inference_params, ) - return output + if labels is None or loss_mask is None: + return output + + return output, new_loss_mask def _load_state_dict_hook_ignore_param_names( diff --git a/megatron/core/models/vision/clip_vit_model.py b/megatron/core/models/vision/clip_vit_model.py index 2b7e281873..6a37883109 100644 --- a/megatron/core/models/vision/clip_vit_model.py +++ b/megatron/core/models/vision/clip_vit_model.py @@ -150,3 +150,11 @@ def forward( x = x.contiguous() return x + + +def get_image_sequence_length(img_h, img_w, patch_dim, add_class_token, class_token_len): + """Get image sequence length given image size, patch size, and class token.""" + num_patches_per_dim_h = img_h // patch_dim + num_patches_per_dim_w = img_w // patch_dim + num_patches = num_patches_per_dim_h * num_patches_per_dim_w + return num_patches + (class_token_len if add_class_token else 0) diff --git a/pretrain_vlm.py b/pretrain_vlm.py index 334f1f8a0d..678e2ffc4f 100644 --- a/pretrain_vlm.py +++ b/pretrain_vlm.py @@ -2,18 +2,22 @@ """Pretrain vision language model.""" from copy import deepcopy from functools import partial -from types import SimpleNamespace import torch from megatron.core import parallel_state, tensor_parallel from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder -from megatron.core.datasets.gpt_dataset import MockGPTLowLevelDataset from megatron.core.datasets.multimodal_dataset import MockMultimodalDataset, MultimodalDatasetConfig from megatron.core.enums import ModelType -from megatron.core.models.multimodal.llava_model import LLaVAModel -from megatron.core.models.multimodal.llava_spec import decoder_model_with_transformer_engine_default_spec, decoder_model_with_local_default_spec -from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec, get_vit_layer_with_local_spec +from megatron.core.models.multimodal.llava_model import LLaVAModel, IMAGE_TOKEN_INDEX +from megatron.core.models.multimodal.llava_spec import ( + decoder_model_with_transformer_engine_default_spec, + decoder_model_with_local_default_spec, +) +from megatron.core.models.vision.vit_layer_specs import ( + get_vit_layer_with_transformer_engine_spec, + get_vit_layer_with_local_spec, +) from megatron.core.transformer.spec_utils import import_module from megatron.training import get_args, get_timers, get_tokenizer, pretrain, print_rank_0 from megatron.training.arguments import core_transformer_config_from_args @@ -32,8 +36,8 @@ def get_num_image_tokens(): def model_provider( - pre_process=True, post_process=True, add_encoder=True, add_decoder=True, - parallel_output=True) -> LLaVAModel: + pre_process=True, post_process=True, add_encoder=True, add_decoder=True, parallel_output=True +) -> LLaVAModel: """Builds the model. Note: currently, only LLaVA model is supported. Follow-up changes will make this configurable. 
@@ -84,12 +88,22 @@ def model_provider( vision_projection_config = deepcopy(language_transformer_config) if args.encoder_pipeline_model_parallel_size > 0: - assert args.encoder_pipeline_model_parallel_size == 1, "ViT can only live on 1 pipeline stage." - vision_transformer_config.pipeline_model_parallel_size = args.encoder_pipeline_model_parallel_size - vision_projection_config.pipeline_model_parallel_size = args.encoder_pipeline_model_parallel_size + assert ( + args.encoder_pipeline_model_parallel_size == 1 + ), "ViT can only live on 1 pipeline stage." + vision_transformer_config.pipeline_model_parallel_size = ( + args.encoder_pipeline_model_parallel_size + ) + vision_projection_config.pipeline_model_parallel_size = ( + args.encoder_pipeline_model_parallel_size + ) if args.encoder_tensor_model_parallel_size > 0: - vision_transformer_config.tensor_model_parallel_size = args.encoder_tensor_model_parallel_size - vision_projection_config.tensor_model_parallel_size = args.encoder_tensor_model_parallel_size + vision_transformer_config.tensor_model_parallel_size = ( + args.encoder_tensor_model_parallel_size + ) + vision_projection_config.tensor_model_parallel_size = ( + args.encoder_tensor_model_parallel_size + ) vision_projection_modules = deepcopy(language_transformer_layer_spec.submodules.mlp.submodules) @@ -133,7 +147,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): config = MultimodalDatasetConfig( random_seed=args.seed, split=args.split, - sequence_length=args.decoder_seq_length-args.seq_length, + sequence_length=args.decoder_seq_length - args.seq_length, tokenizer=get_tokenizer(), reset_position_ids=args.reset_position_ids, reset_attention_mask=args.reset_attention_mask, @@ -146,8 +160,10 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): print_rank_0("> building train, validation, and test datasets for multimodal ...") train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder( - MockMultimodalDataset, train_val_test_num_samples, - lambda: parallel_state.get_tensor_model_parallel_rank() == 0, config + MockMultimodalDataset, + train_val_test_num_samples, + lambda: parallel_state.get_tensor_model_parallel_rank() == 0, + config, ).build() print_rank_0("> finished creating multimodal datasets ...") @@ -166,21 +182,27 @@ def _preprocess_data_for_llava(data): Returns: data (dict): Processed data sample suitable for the model. """ - args = get_args() - - # TODO: Move these to multimodal spec (added in a separate code change). - num_image_tokens = get_num_image_tokens() - + # Prepend image token index to tokens. + data["tokens"] = torch.cat( + [ + IMAGE_TOKEN_INDEX + * torch.ones(1, dtype=data["tokens"].dtype, device=data["tokens"].device), + data["tokens"], + ] + ) + # Prepend labels accordingly. + data["labels"] = torch.cat([data["tokens"][1].unsqueeze(0), data["labels"]]) + # Zero loss mask for the image token index. data["loss_mask"] = torch.cat( - [torch.zeros(num_image_tokens, dtype=torch.float32), data["loss_mask"]] + [ + torch.zeros(1, dtype=data["loss_mask"].dtype, device=data["loss_mask"].device), + data["loss_mask"], + ] + ) + # Add one more position id. 
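+ # The prepended image token index makes the sequence one token longer, so position_ids is extended by one at the end.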
+ data["position_ids"] = torch.cat( + [data["position_ids"], data["position_ids"][-1].unsqueeze(0) + 1] ) - data["labels"] = torch.cat([torch.zeros(num_image_tokens, dtype=torch.int64), data["labels"]]) - - full_seq_length = len(data["labels"]) - attention_mask = torch.tril(torch.ones((1, full_seq_length, full_seq_length))) - attention_mask = attention_mask < 0.5 - attention_mask[:, num_image_tokens:, num_image_tokens:] = data["attention_mask"] - data["attention_mask"] = attention_mask return data @@ -202,14 +224,13 @@ def get_batch(data_iterator): data_i = tensor_parallel.broadcast_data(["tokens", "position_ids", "labels"], data, torch.int64) data_f = tensor_parallel.broadcast_data(["image", "loss_mask"], data, torch.float32) - data_b = tensor_parallel.broadcast_data(["attention_mask"], data, torch.bool) tokens = data_i["tokens"].long() position_ids = data_i["position_ids"].long() labels = data_i["labels"].long() images = data_f["image"].float() loss_mask = data_f["loss_mask"].float() - attention_mask = data_b["attention_mask"].bool() + attention_mask = None # Use the attention mask type defined in layer spec. Typically no mask for the vision model and causal mask for the vision model. return tokens, position_ids, labels, images, loss_mask, attention_mask @@ -232,7 +253,9 @@ def forward_step(data_iterator, model: LLaVAModel): tokens, position_ids, labels, images, loss_mask, attention_mask = get_batch(data_iterator) timers('batch-generator').stop() - output_tensor = model(images, tokens, position_ids, attention_mask, labels=labels) + output_tensor, loss_mask = model( + images, tokens, position_ids, attention_mask, labels, loss_mask + ) return output_tensor, partial(loss_func, loss_mask) diff --git a/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp1_pp1_dgx_a100_1N8G/golden_values.json b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp1_pp1_dgx_a100_1N8G/golden_values.json index 48ba344dc6..95613eb157 100644 --- a/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp1_pp1_dgx_a100_1N8G/golden_values.json +++ b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp1_pp1_dgx_a100_1N8G/golden_values.json @@ -1 +1 @@ -{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [9.13354, 9.1316, 9.12826, 9.11143, 9.05228, 9.04432, 8.98174, 8.93272, 8.88944, 8.78144]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [3477550.0, 3584234.0, 3475077.0, 3382877.0, 3699618.0, 3478787.0, 3397764.0, 3453754.0, 3425474.0, 3585568.0]}, "iteration_timing_avg": 0.2253964705882353} +{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [9.13455, 9.13251, 9.12855, 9.11268, 9.05516, 9.04352, 8.98424, 8.9352, 8.8928, 8.79364]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [3478602.0, 3585025.0, 3475914.0, 3384266.0, 3700151.0, 3480265.0, 3398670.0, 3454930.0, 3426119.0, 3585909.0]}, "iteration_timing_avg": 0.2253964705882353} diff --git a/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp2_pp3_dgx_a100_1N8G/golden_values.json b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp2_pp3_dgx_a100_1N8G/golden_values.json index 071b3f7536..9408e18a70 100644 --- a/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp2_pp3_dgx_a100_1N8G/golden_values.json +++ 
b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp2_pp3_dgx_a100_1N8G/golden_values.json @@ -1 +1 @@ -{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [9.16322, 9.16145, 9.15634, 9.13855, 9.08919, 9.07158, 9.01348, 8.96303, 8.91984, 8.81963]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [3557155.0, 3663852.0, 3555196.0, 3462965.0, 3779960.0, 3558761.0, 3477375.0, 3533357.0, 3505070.0, 3665113.0]}, "iteration_timing_avg": 0.2253964705882353} \ No newline at end of file +{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [9.16216, 9.16272, 9.15753, 9.14108, 9.09527, 9.07229, 9.01583, 8.96745, 8.92202, 8.83118]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [3558559.0, 3664672.0, 3555664.0, 3463897.0, 3780688.0, 3560220.0, 3478422.0, 3535024.0, 3506032.0, 3666249.0]}, "iteration_timing_avg": 0.2253964705882353} diff --git a/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_etp3_dgx_a100_1N7G/golden_values.json b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_etp3_dgx_a100_1N7G/golden_values.json index 4fb81ef651..261295666a 100644 --- a/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_etp3_dgx_a100_1N7G/golden_values.json +++ b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_etp3_dgx_a100_1N7G/golden_values.json @@ -1 +1 @@ -{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [9.19896, 9.20165, 9.19473, 9.17429, 9.11918, 9.10248, 9.04068, 8.98319, 8.94029, 8.83684]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [3717549.0, 3824075.0, 3714573.0, 3622935.0, 3939733.0, 3718925.0, 3637303.0, 3694170.0, 3665707.0, 3824976.0]}, "iteration_timing_avg": 0.5847132352941178} \ No newline at end of file +{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [9.19795, 9.20023, 9.19544, 9.17244, 9.11854, 9.1031, 9.04185, 8.98723, 8.94423, 8.84517]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [3718669.0, 3825107.0, 3715731.0, 3623999.0, 3940369.0, 3720312.0, 3638182.0, 3695283.0, 3666175.0, 3826111.0]}, "iteration_timing_avg": 0.5847132352941178} \ No newline at end of file diff --git a/tests/unit_tests/models/test_llava_model.py b/tests/unit_tests/models/test_llava_model.py index babb7dd1ec..d503f6783b 100644 --- a/tests/unit_tests/models/test_llava_model.py +++ b/tests/unit_tests/models/test_llava_model.py @@ -69,47 +69,190 @@ def test_set_input_tensor(self): self.model.set_input_tensor(input_tensor) assert self.model.vision_model.decoder.input_tensor.shape == expected_shape + @pytest.mark.internal + def test_preprocess_data(self): + self.model.cuda() + + image_embedding_value = torch.tensor(123.0) + image_embeddings = image_embedding_value * torch.ones((577, 3, 128)).cuda() + + image_token_index = -200 + input_ids = torch.arange(0, 1024, dtype=torch.int).expand(4, 1024).cuda() + input_ids[0, 0] = image_token_index # image before text + input_ids[1, 100] = image_token_index # image in between + input_ids[2, -1] = image_token_index # image at the end + # input_ids[3] - no image + + language_embedding_value = torch.tensor(999.0) + language_embeddings = language_embedding_value * torch.ones((4, 1024, 128)).cuda() + + # Labels are input_ids shifted to left by one. 
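+ # For the samples with an image in between or at the end, the image token also appears in the labels, one position earlier than in input_ids.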
+ labels = torch.arange(1, 1025, dtype=torch.int).expand(4, 1024).cuda() + labels[1, 99] = image_token_index + labels[2, -2] = image_token_index + + loss_mask = torch.ones((4, 1024), dtype=torch.int).cuda() + # Mask some text inputs (the text mask should carry over) + loss_mask[:2, :10] = 0 + loss_mask[:2, 110:120] = 0 + + use_inference_kv_cache = False + + embeddings, labels, loss_mask = self.model._preprocess_data( + image_embeddings, + language_embeddings, + input_ids, + loss_mask, + labels, + use_inference_kv_cache, + image_token_index, + ) + + assert embeddings.shape == torch.Size((1600, 4, 128)) + assert labels.shape == torch.Size((4, 1600)) + assert loss_mask.shape == labels.shape + + # First sample where image is before text (index 0). + expected_embeddings = torch.empty(1600).cuda() + expected_embeddings[:577] = image_embedding_value + expected_embeddings[577:] = language_embedding_value + + expected_labels = torch.empty(1600, dtype=torch.int).cuda() + expected_labels[:576] = -100 + expected_labels[576:] = torch.arange(1, 1025, dtype=torch.int) + + expected_loss_mask = torch.empty(1600, dtype=torch.int).cuda() + expected_loss_mask[:577] = 0 + expected_loss_mask[577:586] = 0 + expected_loss_mask[586:686] = 1 + expected_loss_mask[686:696] = 0 + expected_loss_mask[696:] = 1 + + assert torch.allclose(embeddings[:, 0], expected_embeddings.unsqueeze(1)) + assert torch.allclose(labels[0], expected_labels) + assert torch.allclose(loss_mask[0], expected_loss_mask) + + # Second sample where image is in between (index 100). + expected_embeddings = torch.empty(1600).cuda() + expected_embeddings[:100] = language_embedding_value + expected_embeddings[100:677] = image_embedding_value + expected_embeddings[677:] = language_embedding_value + + expected_labels = torch.empty(1600, dtype=torch.int).cuda() + expected_labels[:99] = torch.arange(1, 100) + expected_labels[99:676] = -100 + expected_labels[676:] = torch.arange(101, 1025) + + expected_loss_mask = torch.empty(1600, dtype=torch.int).cuda() + expected_loss_mask[:10] = 0 + expected_loss_mask[10:99] = 1 + expected_loss_mask[99] = ( + 0 # Last text position before the image is not required to predict the first image embedding. + ) + expected_loss_mask[100:677] = 0 + expected_loss_mask[677:686] = 1 + expected_loss_mask[686:696] = 0 + expected_loss_mask[696:] = 1 + + assert torch.allclose(embeddings[:, 1], expected_embeddings.unsqueeze(1)) + assert torch.allclose(labels[1], expected_labels) + assert torch.allclose(loss_mask[1], expected_loss_mask) + + # Third sample where image is at the end. + expected_embeddings = torch.empty(1600).cuda() + expected_embeddings[:1023] = language_embedding_value + expected_embeddings[1023:] = image_embedding_value + + expected_labels = torch.empty(1600, dtype=torch.int).cuda() + expected_labels[:1022] = torch.arange(1, 1023) + expected_labels[1022:1599] = -100 + expected_labels[1599] = 1024 + + expected_loss_mask = torch.empty(1600, dtype=torch.int).cuda() + expected_loss_mask[:1022] = 1 + expected_loss_mask[1022] = ( + 0 # Last text position before the image is not required to predict the first image embedding. + ) + expected_loss_mask[1023:] = 0 + + assert torch.allclose(embeddings[:, 2], expected_embeddings.unsqueeze(1)) + assert torch.allclose(labels[2], expected_labels) + assert torch.allclose(loss_mask[2], expected_loss_mask) + + # Fourth sample where there is no image. 
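+ # The text-only sample fills the first 1024 positions; the remaining 576 positions are right padding with zero embeddings, ignored labels and zero loss mask.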
+ expected_embeddings = torch.empty(1600).cuda() + expected_embeddings[:1024] = language_embedding_value + expected_embeddings[1024:] = 0 # padding + + expected_labels = torch.empty(1600, dtype=torch.int).cuda() + expected_labels[:1024] = torch.arange(1, 1025) + expected_labels[1024:] = -100 + + expected_loss_mask = torch.empty(1600, dtype=torch.int).cuda() + expected_loss_mask[:1024] = 1 + expected_loss_mask[1024:] = 0 + + assert torch.allclose(embeddings[:, 3], expected_embeddings.unsqueeze(1)) + assert torch.allclose(labels[3], expected_labels) + assert torch.allclose(loss_mask[3], expected_loss_mask) + @pytest.mark.internal def test_forward(self): self.model.cuda() - img = torch.randn((2, 3, 336, 336)).cuda() - input_ids = torch.randint(0, 2048, (2, 1024)).cuda() - position_ids = torch.arange(0, 1024, dtype=torch.int).cuda() - position_ids = position_ids.expand(2, 1024) - # With default image and patch sizes of 336 and 14, respectively, and a class token, the combined sequence length is 1024 + (336/14) ** 2 + 1 = 1601. - attention_mask = torch.tril(torch.ones((2, 1, 1601, 1601))).cuda() - attention_mask = attention_mask < 0.5 - labels = torch.randint(0, 2048, (2, 1601)).cuda() + img = torch.randn((3, 3, 336, 336)).cuda() + + image_token_index = -200 + input_ids = torch.randint(0, 2048, (4, 1024)).cuda() + input_ids[0, 0] = image_token_index # image before text + input_ids[1, 100] = image_token_index # image in between + input_ids[2, -1] = image_token_index # image at the end + # input_ids[3] - no image + + position_ids = torch.arange(0, 1024, dtype=torch.int).expand(4, 1024).cuda() + + loss_mask = torch.ones((4, 1024)).cuda() + + attention_mask = None # Causal. + + labels = torch.randint(0, 2048, (4, 1024)).cuda() + labels[1, 99] = image_token_index + labels[2, -2] = image_token_index # Try with labels. - loss = self.model.forward(img, input_ids, position_ids, attention_mask, labels=labels) - assert loss.shape == torch.Size((2, 1601)) + loss, new_loss_mask = self.model.forward( + img, input_ids, position_ids, attention_mask, labels, loss_mask + ) + # The final sequence length 1600 comes from 577 image tokens and 1023 text tokens. + assert loss.shape == new_loss_mask.shape == torch.Size((4, 1600)) # Try without labels and without inference params. - logits = self.model.forward(img, input_ids, position_ids, attention_mask, labels=None) - assert logits.shape == torch.Size((2, 1601, 2048)) + logits = self.model.forward( + img, input_ids, position_ids, attention_mask, labels=None, loss_mask=None + ) + assert logits.shape == torch.Size((4, 1600, 2048)) # Try without labels and with inference params. - inference_params = InferenceParams(2, 1601) + inference_params = InferenceParams(4, 1600) logits = self.model.forward( img, input_ids, position_ids, attention_mask, labels=None, + loss_mask=None, inference_params=inference_params, ) - assert logits.shape == torch.Size((2, 1601, 2048)) + assert logits.shape == torch.Size((4, 1600, 2048)) - # Check KV cache got created correctly. + # Check KV cache got populated correctly. kv_dict = inference_params.key_value_memory_dict assert kv_dict["image_tokens_count"] == 577 for layer_no in range(1, 4): # 3 layers in the model. layer_kv = kv_dict[layer_no] # Expected shape is [sequence_len, batch_size, num_heads, hidden_size_per_head] - assert layer_kv[0].shape == layer_kv[1].shape == torch.Size((1601, 2, 8, 16)) + assert layer_kv[0].shape == layer_kv[1].shape == torch.Size((1600, 4, 8, 16)) @pytest.mark.internal def test_save_load(self, tmp_path):