187 changes: 172 additions & 15 deletions nemo/collections/vlm/gemma3vl/data/task_encoder.py
@@ -15,11 +15,11 @@
import json
import logging
from dataclasses import dataclass, field
from typing import Optional
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from megatron.energon import VQASample
from megatron.energon import InterleavedSample, VQASample

from nemo.collections.vlm.data.task_encoder import DataBatch, DataSample
from nemo.collections.vlm.data.task_encoder import TaskEncoder as BaseTaskEncoder
@@ -39,6 +39,8 @@ class TaskEncoderConfig(BaseTaskEncoderConfig):

stop_string: Optional[str] = ""
system_prompt: Optional[str] = None
image_token_str: str = "<start_of_image>"
image_token_id: int = 262144 # This is the token id for <image_soft_token>
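A minimal sketch of how these two fields fit together, assuming the required fields inherited from BaseTaskEncoderConfig (tokenizer/processor paths, sequence length, and so on) are supplied as usual:

config = TaskEncoderConfig(
    stop_string="",
    system_prompt=None,
    image_token_str="<start_of_image>",  # placeholder the HF processor expands in the text
    image_token_id=262144,               # id of <image_soft_token> produced by the expansion
)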


@dataclass
@@ -60,7 +62,7 @@ class TaskEncoder(BaseTaskEncoder):

The encoder supports:
- VQA samples: Processing image-question pairs with corresponding answers
- [In progress] Interleaved samples: Processing alternating image and text content
- Interleaved samples: Processing alternating image and text content
- [In progress] Similarity interleaved samples: Processing image-text pairs for similarity tasks
- [In progress] Packed sequences: Efficient processing of multiple samples in a single sequence

@@ -85,6 +87,7 @@ def __init__(self, config: TaskEncoderConfig):
# Initialize encoders with the config
self.encoders = {
"VQASample": self.encode_vqa_sample,
"InterleavedSample": self.encode_interleaved_sample,
}

def encode_batch(self, batch_data: DataBatch) -> dict:
@@ -200,20 +203,10 @@ def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:

# Pad tokens and labels to a multiple of `pad_to_multiple_of` if specified
if self.config.pad_to_multiple_of:
seqlen_padded = (
(seqlen + self.config.pad_to_multiple_of - 1)
// self.config.pad_to_multiple_of
* self.config.pad_to_multiple_of
)
pad_len = seqlen_padded - seqlen

if pad_len > 0:
tokens = F.pad(tokens, (0, pad_len), 'constant', 0)
labels = F.pad(labels, (0, pad_len), 'constant', self.config.ignore_place_holder)
tokens, labels = self.pad_tokens_and_labels(tokens, labels, seqlen)

# Compute loss mask
loss_mask = torch.ones_like(labels, dtype=torch.float)
loss_mask[labels < 0] = 0.0
loss_mask = self.compute_loss_mask(labels)

# Convert images to bfloat16 and stack, or create an empty tensor if no images
if images is not None and images.numel() > 0:
@@ -237,3 +230,167 @@ def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:
)

return sample

def tokenize_interleaved_sample(
self, input_sample: InterleavedSample
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Tokenize the input sequence and process images in an interleaved sample.

This method processes a sequence that consists of text strings and image tensors.
The text is tokenized, and the images are processed. The method returns a tensor
of tokenized text and a concatenated tensor of processed images.

Parameters:
input_sample (InterleavedSample): The interleaved sample containing a sequence of text strings and image tensors.

Returns:
Tuple[torch.Tensor, Optional[torch.Tensor]]: A tuple containing:
- A tensor of tokenized text with image token IDs inserted.
- A tensor of processed images, or None if the sample contains no images.
"""
texts, images = [], []
for item in input_sample.sequence:
if isinstance(item, str):
texts.append(item)
elif isinstance(item, torch.Tensor):
images.append(item)
# Append the image placeholder token; the HF processor replaces it
# with the actual image tokens during processing.
texts.append(self.config.image_token_str)
else:
raise ValueError(f"Unsupported item type in interleaved sequence: {type(item)}")

outputs = self.hf_processor(
images=[images] if images else None,  # Wrap in a list to form a batch of size one; None if there are no images.
text=" ".join(texts),
return_tensors="pt",
images_kwargs={"do_rescale": False},
)
# Get tokens and images from processor output
# Squeeze the batch dimension as we process one sample at a time
tokens = outputs["input_ids"].squeeze(0)
images = outputs.get("pixel_values") # Use .get() for optional images

# Convert images to bfloat16 and stack, or create an empty tensor if no images
if images is not None and images.numel() > 0:
# Ensure images tensor is on the same device as tokens/labels if needed
images = images.to(device=tokens.device, dtype=torch.bfloat16)
processed_images = images # Already stacked by HF processor if multiple images/frames
else:
# No images in this sample; leave the pixel values as None
processed_images = None
return tokens, processed_images
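A hedged usage sketch of tokenize_interleaved_sample; the InterleavedSample field values and the 3x224x224 image shape are illustrative assumptions, with pixel values already in [0, 1] since the processor is called with do_rescale=False:

sample = InterleavedSample(
    __key__="doc-0",
    __restore_key__=(),
    __subflavor__=None,
    __subflavors__={},
    sequence=["A photo of a cat", torch.rand(3, 224, 224), "and a caption."],
)
tokens, pixel_values = encoder.tokenize_interleaved_sample(sample)
# tokens: 1-D tensor of text token ids plus the expanded image token ids
# pixel_values: bfloat16 image tensor, or None if the sample had no images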

def compute_labels_interleaved(self, tokens: torch.Tensor) -> torch.Tensor:
"""
Compute labels for an interleaved sample, ignoring image token IDs.

This method generates a label tensor in which image tokens are replaced with the
`ignore_place_holder` ID while all other tokens retain their original IDs, then shifts
the result left by one position for next-token prediction.

Parameters:
tokens (torch.Tensor): A tensor containing the tokenized sequence.

Returns:
torch.Tensor: A tensor containing the labels for the tokenized sequence.
"""
labels = tokens.clone()
labels[labels == self.config.image_token_id] = self.config.ignore_place_holder
labels = labels[1:].contiguous()
return labels
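A toy illustration of the masking and the one-position shift, assuming ignore_place_holder is -100 (the usual ignore index) and otherwise fabricated token ids:

tokens = torch.tensor([2, 5, 262144, 262144, 7, 9])
labels = encoder.compute_labels_interleaved(tokens)
# Image tokens are masked first, then everything shifts left by one:
# labels == tensor([5, -100, -100, 7, 9])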

def pad_tokens_and_labels(
self, tokens: torch.Tensor, labels: torch.Tensor, seqlen: int
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Pad tokens and labels to a multiple of `config.pad_to_multiple_of`.

Parameters:
tokens (torch.Tensor): A tensor containing the tokenized sequence.
labels (torch.Tensor): A tensor containing the labels for the tokenized sequence.
seqlen (int): Original sequence length before padding

Returns:
Tuple[torch.Tensor, torch.Tensor]: The tokens and labels tensors, padded to a multiple of config.pad_to_multiple_of.
"""
seqlen_padded = (
(seqlen + self.config.pad_to_multiple_of - 1)
// self.config.pad_to_multiple_of
* self.config.pad_to_multiple_of
)
pad_len = seqlen_padded - seqlen

if pad_len > 0:
tokens = F.pad(tokens, (0, pad_len), 'constant', 0)
labels = F.pad(labels, (0, pad_len), 'constant', self.config.ignore_place_holder)
return tokens, labels
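The round-up arithmetic, worked through for pad_to_multiple_of=8 and an original length of 13:

seqlen, multiple = 13, 8
seqlen_padded = (seqlen + multiple - 1) // multiple * multiple  # 16
pad_len = seqlen_padded - seqlen  # 3: tokens get three 0s, labels three ignore ids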

def compute_loss_mask(self, labels: torch.Tensor) -> torch.Tensor:
"""
Compute the loss mask based on which label values are negative.

Parameters:
labels (torch.Tensor): A tensor containing the labels for the tokenized sequence.

Returns:
torch.Tensor: The computed loss mask.
"""

loss_mask = torch.ones_like(labels, dtype=torch.float)
loss_mask[labels < 0] = 0.0
return loss_mask
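A tiny check of the mask behavior, again assuming a negative ignore index such as -100:

labels = torch.tensor([5, -100, 7])
loss_mask = encoder.compute_loss_mask(labels)
# loss_mask == tensor([1., 0., 1.]): ignored positions contribute nothing to the loss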

def encode_interleaved_sample(self, input_sample: InterleavedSample) -> DataSample:
"""
Encode an interleaved sample.

This method tokenizes the input sequence, computes labels and a loss mask, and processes
the images. The encoded sample is then stored in the output_sample object.

Parameters:
input_sample (InterleavedSample): The interleaved sample to be encoded.

Returns:
Gemma3DataSample: Encoded sample with processed images, tokens, position ids, labels, and loss mask
"""
logging.info(f"The config is: {self.config}")
logging.info(f"input_sample={input_sample}")
tokens, processed_images = self.tokenize_interleaved_sample(input_sample)

logging.info(f"decode encoded tokens: {self.tokenizer.tokenizer.decode(tokens)}")

# --- Label Generation ---
labels = self.compute_labels_interleaved(tokens)

logging.info(f"encoded:===== input_sample={input_sample}, tokens={tokens}, labels={labels}")
# Prepare final tensors
tokens = tokens[:-1]
seqlen = len(tokens) # Original sequence length before padding
position_ids = torch.arange(seqlen, dtype=torch.int64)

logging.debug(f"data encoder: position_ids = {position_ids}")

# Pad tokens and labels to a multiple of `pad_to_multiple_of` if specified
if self.config.pad_to_multiple_of:
tokens, labels = self.pad_tokens_and_labels(tokens, labels, seqlen)

# Compute loss mask
loss_mask = self.compute_loss_mask(labels)

logging.debug(f"There are {(labels > 0).sum()} valid labels.")

sample = Gemma3DataSample(
__key__=input_sample.__key__,
__restore_key__=input_sample.__restore_key__,
__subflavor__=input_sample.__subflavor__,
__subflavors__=input_sample.__subflavors__,
pixel_values=processed_images,
input_ids=tokens,
position_ids=position_ids,
labels=labels,
loss_mask=loss_mask,
)
logging.debug(f"Gemma3 task encoder: sample: {sample.position_ids}")

return sample
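An end-to-end sketch of the dispatch path; the construction of config and sample here is elided and hypothetical:

encoder = TaskEncoder(config)
encoded = encoder.encoders["InterleavedSample"](sample)
# input_ids, labels, and loss_mask all share the (possibly padded) sequence length
assert encoded.input_ids.shape == encoded.labels.shape == encoded.loss_mask.shape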