Skip to content

Commit

Permalink
ADLR/megatron-lm!1402 - LLaVA expanded data processing
Browse files: browse the repository at this point in the history
  • Loading branch information
trintamaki authored and ko3n1g committed Aug 15, 2024
1 parent 44cc262 commit b1e36c4
Show file tree
Hide file tree
Showing 10 changed files with 455 additions and 111 deletions.
14 changes: 9 additions & 5 deletions .gitlab/stages/01.tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ unit_tests:
parallel:
matrix:
- TAG: latest
- TAG: db5c60ae3fe5247f16ec0536bbf41ee5c7fb9c4a
- TAG: a5efe829b1d34c691f0a7a5286e271b4f9c86b2a
tags: [8xL40S]
variables:
GIT_STRATEGY: clone
Expand All @@ -89,11 +89,15 @@ unit_tests:
cp -r tests/ /opt/megatron-lm
fi
script:
- |
cd /opt/megatron-lm
- |
cd /opt/megatron-lm
for i in $(seq $UNIT_TEST_REPEAT); do
SEED=$((RANDOM % 9000 + 1000));
timeout ${UNIT_TEST_TIMEOUT}m torchrun --nproc_per_node=8 -m pytest --random-order --random-order-seed ${SEED} -xvs --cov-report=term --cov-report=html --cov=megatron/core --no-cov-on-fail `$([[ $TAG != latest ]] && echo -m 'not internal')` tests/unit_tests
SKIPPED=()
if [[ $TAG != latest ]]; then
SKIPPED+=(-m "not internal")
fi
timeout ${UNIT_TEST_TIMEOUT}m torchrun --nproc_per_node=8 -m pytest --random-order --random-order-seed ${SEED} -xvs --cov-report=term --cov-report=html --cov=megatron/core --no-cov-on-fail "${SKIPPED[@]}" tests/unit_tests
done
artifacts:
paths:
Expand Down Expand Up @@ -143,7 +147,7 @@ secret_detection:
- apk add jq
- /analyzer run
- |
if [[ $(cat gl-secret-detection-report.json | jq '.vulnerabilities | length > 0') == true ]]; then
if [[ $(cat gl-secret-detection-report.json | jq '.vulnerabilities | length > 0') == true ]]; then
echo "Atleast one vulnerability has been found"
cat gl-secret-detection-report.json | jq '.'
exit 1
Expand Down
10 changes: 9 additions & 1 deletion examples/multimodal/run_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from torchvision.transforms import Compose, Resize, ToPILImage
from train import add_multimodal_extra_args, get_image_token_count, model_provider

from megatron.core.models.multimodal.llava_model import IMAGE_TOKEN_INDEX
from megatron.inference.text_generation.api import generate_and_post_process
from megatron.inference.text_generation.forward_step import ForwardStep
from megatron.training import get_args, get_model, print_rank_0
Expand Down Expand Up @@ -282,7 +283,7 @@ def generate_samples(model):
elif args.task in ("TextVQA", "MMMU"):
output_name = "text"

generated = generation[len(prompt) + 1 :]
generated = generation[len(prompt):]
output[output_name] = generated

if args.task == "captioning":
Expand Down Expand Up @@ -329,6 +330,13 @@ def __init__(self, images, num_image_tokens, model, max_batch_size, max_sequence
self._images = images

def _forward(self, tokens, position_ids, attention_mask):
# Add image token index to the front if it's not included in the prompt. Note: This will change in a future MR.
num_tokens = tokens.shape[1]

if num_tokens > 1 and torch.sum(tokens == IMAGE_TOKEN_INDEX).item() == 0:
tokens = torch.cat([torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=tokens.dtype, device=tokens.device), tokens], dim=1)
position_ids = torch.arange(num_tokens, dtype=position_ids.dtype, device=position_ids.device)

return self.model(
self._images,
tokens,
Expand Down
31 changes: 2 additions & 29 deletions examples/multimodal/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def model_provider(
vision_config.pipeline_model_parallel_size = args.encoder_pipeline_model_parallel_size
vision_projection_config.pipeline_model_parallel_size = args.encoder_pipeline_model_parallel_size
if args.encoder_tensor_model_parallel_size > 0:
vision_transformer_config.tensor_model_parallel_size = args.encoder_tensor_model_parallel_size
vision_config.tensor_model_parallel_size = args.encoder_tensor_model_parallel_size
vision_projection_config.tensor_model_parallel_size = args.encoder_tensor_model_parallel_size

vision_projection_layer_spec = get_mlp_module_spec(use_te=use_te).submodules
Expand Down Expand Up @@ -113,7 +113,6 @@ def model_provider(
img_w=args.img_w,
patch_dim=args.patch_dim,
language_rotary_base=args.rotary_base,
img_embedding_idx=args.img_embedding_idx,
)

model.freeze(freeze_language_model=args.freeze_LM, freeze_vision_model=args.freeze_ViT, freeze_vision_projection=False)
Expand Down Expand Up @@ -171,10 +170,6 @@ def get_batch(data_iterator):
question_length=prompt_len)
torch.cuda.nvtx.range_pop()

loss_mask, labels, attention_mask = _preprocess_data_for_llava(loss_mask, labels, attention_mask)

tokens = tokens[:, 1:] # drop image index token

return tokens, labels, loss_mask, attention_mask, position_ids, img_raw


Expand All @@ -191,24 +186,6 @@ def get_image_token_count():
return num_image_tokens


def _preprocess_data_for_llava(loss_mask, labels, attention_mask):
    """Left-pad the loss mask and labels and build a fresh causal attention mask.

    Both ``loss_mask`` and ``labels`` get ``get_image_token_count() - 1`` zero
    columns prepended along dim 1. The attention mask is rebuilt from scratch
    for the padded sequence length; the incoming ``attention_mask`` is only
    consulted for its device.

    Args:
        loss_mask: 2-D tensor; its first dimension is used as the batch size
            for both padding tensors.
        labels: 2-D tensor, padded with int64 zeros on its own device.
        attention_mask: Tensor whose device hosts the rebuilt mask.

    Returns:
        Tuple ``(loss_mask, labels, attention_mask)`` where the first two are
        the padded tensors and the last is a boolean tensor of shape
        ``(1, 1, seq, seq)`` that is True strictly above the diagonal.
    """
    # NOTE(review): the "- 1" presumably accounts for an image placeholder
    # token already present in the sequence -- confirm against the caller.
    pad_width = get_image_token_count() - 1
    num_rows = loss_mask.shape[0]

    def _left_pad(tensor, dtype):
        # Zero block lives on the same device as the tensor it is prepended to.
        zeros = torch.zeros(num_rows, pad_width, dtype=dtype, device=tensor.device)
        return torch.cat([zeros, tensor], dim=1)

    padded_loss_mask = _left_pad(loss_mask, torch.float32)
    padded_labels = _left_pad(labels, torch.int64)

    total_len = len(padded_labels[0])
    lower_triangle = torch.tril(
        torch.ones((1, 1, total_len, total_len), device=attention_mask.device)
    )
    # Ones on/below the diagonal become False; zeros above it become True.
    causal_mask = lower_triangle < 0.5

    return padded_loss_mask, padded_labels, causal_mask


def get_ltor_masks_and_position_ids(data,
eod_token,
reset_position_ids,
Expand Down Expand Up @@ -312,7 +289,7 @@ def forward_step(data_iterator, model: LLaVAModel):
tokens, labels, loss_mask, attention_mask, position_ids, images = get_batch(data_iterator)
timers('batch-generator').stop()

output_tensor = model(images, tokens, position_ids, attention_mask, labels=labels)
output_tensor, loss_mask = model(images, tokens, position_ids, attention_mask, labels, loss_mask)

return output_tensor, partial(loss_func, loss_mask)

Expand All @@ -332,10 +309,6 @@ def add_multimodal_extra_args(parser):
group.add_argument("--disable-vision-class-token", action="store_true", default=False)
group.add_argument("--allow-missing-vision-projection-checkpoint", action="store_true", default=False)
group.add_argument("--use-te", action="store_true", default=False)
group.add_argument("--img-embedding-idx", type=int, default=0,
help='Llava specific parameter. Defines at which index'
'in the language_embedding tensor the image_embeddings'
'should be inserted')
group.add_argument("--dataloader-save", type=str, default=None, help="Energon dataloader state save path")
return parser

Expand Down
Loading

0 comments on commit b1e36c4

Please sign in to comment.