Skip to content

Commit

Permalink
Merge branch 'trintamaki/llava-pp-fix' into 'main'
Browse files Browse the repository at this point in the history
Llava pp > 1 fix

See merge request ADLR/megatron-lm!2441
  • Loading branch information
ericharper committed Dec 13, 2024
2 parents 40db706 + de18820 commit 183f568
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions examples/multimodal/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def get_batch(data_iterator):
pp_size = get_pipeline_model_parallel_world_size()
if not is_first_or_last_stage(pp_size, args.encoder_pipeline_model_parallel_size):
# Note these are all set to None above.
return tokens, labels, loss_mask, attention_mask, position_ids, imgs, num_tiles
return tokens, labels, loss_mask, attention_mask, position_ids, imgs, num_tiles, packed_seq_params

# Broadcast data.
torch.cuda.nvtx.range_push("get_data")
Expand All @@ -66,7 +66,7 @@ def get_batch(data_iterator):
cu_lengths = tensor_parallel.broadcast_data(["cu_lengths"], data, torch.int32)["cu_lengths"]
max_lengths = tensor_parallel.broadcast_data(["max_lengths"], data, torch.int32)["max_lengths"]

# Dummy image, no image.
# No image input (text-only sample) if the dataloader produced a dummy image.
if imgs.shape == torch.Size([1, 1]):
# FIXME: text-only data can cause a hang if the vision model is on its own pipeline rank and --freeze-ViT is enabled.
imgs = torch.tensor([], dtype=torch.float32, device=data_text.device)
Expand Down
2 changes: 1 addition & 1 deletion megatron/core/models/multimodal/llava_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,7 @@ def forward(
).contiguous() # [b, text_seq_len, h_language]

# Assume 1 tile per image if the number of tiles is not provided.
if num_image_tiles is None:
if num_image_tiles is None and images is not None:
num_image_tiles = torch.ones(images.shape[0], dtype=torch.int, device=input_ids.device)

combined_embeddings, new_labels, new_loss_mask = self._preprocess_data(
Expand Down

0 comments on commit 183f568

Please sign in to comment.