diff --git a/colpali_engine/models/__init__.py b/colpali_engine/models/__init__.py index 0594793e..db098494 100644 --- a/colpali_engine/models/__init__.py +++ b/colpali_engine/models/__init__.py @@ -1,3 +1,4 @@ +from .florence2 import ColFlor, ColFlorProcessor from .idefics_2 import BiIdefics2, ColIdefics2, ColIdefics2Processor from .paligemma import BiPali, BiPaliProcessor, BiPaliProj, ColPali, ColPaliProcessor from .qwen2 import BiQwen2, BiQwen2Processor, ColQwen2, ColQwen2Processor diff --git a/colpali_engine/models/florence2/__init__.py b/colpali_engine/models/florence2/__init__.py new file mode 100644 index 00000000..38a45bc5 --- /dev/null +++ b/colpali_engine/models/florence2/__init__.py @@ -0,0 +1 @@ +from .colflor import ColFlor, ColFlorProcessor diff --git a/colpali_engine/models/florence2/colflor/__init__.py b/colpali_engine/models/florence2/colflor/__init__.py new file mode 100644 index 00000000..3c3c1050 --- /dev/null +++ b/colpali_engine/models/florence2/colflor/__init__.py @@ -0,0 +1,2 @@ +from .modeling_colflor import ColFlor +from .processing_colflor import ColFlorProcessor diff --git a/colpali_engine/models/florence2/colflor/configuration_florence2.py b/colpali_engine/models/florence2/colflor/configuration_florence2.py new file mode 100644 index 00000000..b7250ebf --- /dev/null +++ b/colpali_engine/models/florence2/colflor/configuration_florence2.py @@ -0,0 +1,339 @@ +# ruff: noqa +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import warnings + +""" Florence-2 configuration""" + + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +class Florence2VisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel + according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Florence2VisionModel architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + drop_path_rate (`float`, *optional*, defaults to 0.1): + The dropout rate of the drop path layer. + patch_size (`List[int]`, *optional*, defaults to [7, 3, 3, 3]): + The patch size of the image. + patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2, 2]): + The patch stride of the image. + patch_padding (`List[int]`, *optional*, defaults to [3, 1, 1, 1]): + The patch padding of the image. + patch_prenorm (`List[bool]`, *optional*, defaults to [false, true, true, true]): + Whether to apply layer normalization before the patch embedding layer. + enable_checkpoint (`bool`, *optional*, defaults to False): + Whether to enable checkpointing. + dim_embed (`List[int]`, *optional*, defaults to [256, 512, 1024, 2048]): + The dimension of the embedding layer. + num_heads (`List[int]`, *optional*, defaults to [8, 16, 32, 64]): + The number of attention heads. + num_groups (`List[int]`, *optional*, defaults to [8, 16, 32, 64]): + The number of groups. + depths (`List[int]`, *optional*, defaults to [1, 1, 9, 1]): + The depth of the model. + window_size (`int`, *optional*, defaults to 12): + The window size of the model. + projection_dim (`int`, *optional*, defaults to 1024): + The dimension of the projection layer. + visual_temporal_embedding (`dict`, *optional*): + The configuration of the visual temporal embedding. + image_pos_embed (`dict`, *optional*): + The configuration of the image position embedding. + image_feature_source (`List[str]`, *optional*, defaults to ["spatial_avg_pool", "temporal_avg_pool"]): + The source of the image feature. + Example: + + ```python + >>> from transformers import Florence2VisionConfig, Florence2VisionModel + + >>> # Initializing a Florence2 Vision style configuration + >>> configuration = Florence2VisionConfig() + + >>> # Initializing a model (with random weights) + >>> model = Florence2VisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "florence2_vision" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + drop_path_rate=0.1, + patch_size=[7, 3, 3, 3], + patch_stride=[4, 2, 2, 2], + patch_padding=[3, 1, 1, 1], + patch_prenorm=[False, True, True, True], + enable_checkpoint=False, + dim_embed=[256, 512, 1024, 2048], + num_heads=[8, 16, 32, 64], + num_groups=[8, 16, 32, 64], + depths=[1, 1, 9, 1], + window_size=12, + projection_dim=1024, + visual_temporal_embedding=None, + image_pos_embed=None, + image_feature_source=["spatial_avg_pool", "temporal_avg_pool"], + **kwargs, + ): + self.drop_path_rate = drop_path_rate + self.patch_size = patch_size + self.patch_stride = patch_stride + self.patch_padding = patch_padding + self.patch_prenorm = patch_prenorm + self.enable_checkpoint = enable_checkpoint + self.dim_embed = dim_embed + self.num_heads = num_heads + self.num_groups = num_groups + self.depths = depths + self.window_size = window_size + self.projection_dim = projection_dim + self.visual_temporal_embedding = visual_temporal_embedding + self.image_pos_embed = image_pos_embed + self.image_feature_source = image_feature_source + + super().__init__(**kwargs) + + + +class Florence2LanguageConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the BART + [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 51289): + Vocabulary size of the Florence2Language model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Florence2LanguageModel`]. + d_model (`int`, *optional*, defaults to 1024): + Dimensionality of the layers and the pooler layer. + encoder_layers (`int`, *optional*, defaults to 12): + Number of encoder layers. + decoder_layers (`int`, *optional*, defaults to 12): + Number of decoder layers. + encoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + decoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_ffn_dim (`int`, *optional*, defaults to 4096): + Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. + encoder_ffn_dim (`int`, *optional*, defaults to 4096): + Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. + activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + classifier_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for classifier. + max_position_embeddings (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + encoder_layerdrop (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) + for more details. + decoder_layerdrop (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) + for more details. + scale_embedding (`bool`, *optional*, defaults to `False`): + Scale embeddings by diving by sqrt(d_model). + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + num_labels (`int`, *optional*, defaults to 3): + The number of labels to use in [`Florence2LanguageForSequenceClassification`]. + forced_eos_token_id (`int`, *optional*, defaults to 2): + The id of the token to force as the last generated token when `max_length` is reached. Usually set to + `eos_token_id`. + + Example: + + ```python + >>> from transformers import Florence2LanguageConfig, Florence2LanguageModel + + >>> # Initializing a Florence2 Language style configuration + >>> configuration = Florence2LanguageConfig() + + >>> # Initializing a model (with random weights) + >>> model = Florence2LangaugeModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "florence2_language" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} + + def __init__( + self, + vocab_size=51289, + max_position_embeddings=1024, + encoder_layers=12, + encoder_ffn_dim=4096, + encoder_attention_heads=16, + decoder_layers=12, + decoder_ffn_dim=4096, + decoder_attention_heads=16, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + activation_function="gelu", + d_model=1024, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + classifier_dropout=0.0, + scale_embedding=False, + use_cache=True, + num_labels=3, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + is_encoder_decoder=True, + decoder_start_token_id=2, + forced_eos_token_id=2, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + + super().__init__( + num_labels=num_labels, + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + forced_eos_token_id=forced_eos_token_id, + **kwargs, + ) + + # ensure backward compatibility for BART CNN models + if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False): + self.forced_bos_token_id = self.bos_token_id + warnings.warn( + f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. " + "The config can simply be saved and uploaded again to be fixed." + ) + +class Florence2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate an + Florence-2 model according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`Florence2VisionConfig`, *optional*): + Custom vision config or dict + text_config (`Union[AutoConfig, dict]`, *optional*): + The config object of the text backbone. + ignore_index (`int`, *optional*, defaults to -100): + The ignore index for the loss function. + vocab_size (`int`, *optional*, defaults to 51289): + Vocabulary size of the Florence2model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~Florence2ForConditionalGeneration`] + projection_dim (`int`, *optional*, defaults to 1024): + Dimension of the multimodal projection space. + + Example: + + ```python + >>> from transformers import Florence2ForConditionalGeneration, Florence2Config, CLIPVisionConfig, BartConfig + + >>> # Initializing a clip-like vision config + >>> vision_config = CLIPVisionConfig() + + >>> # Initializing a Bart config + >>> text_config = BartConfig() + + >>> # Initializing a Florence-2 configuration + >>> configuration = Florence2Config(vision_config, text_config) + + >>> # Initializing a model from the florence-2 configuration + >>> model = Florence2ForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "florence2" + is_composition = False + + def __init__( + self, + vision_config=None, + text_config=None, + ignore_index=-100, + vocab_size=51289, + projection_dim=1024, + **kwargs, + ): + self.ignore_index = ignore_index + self.vocab_size = vocab_size + self.projection_dim = projection_dim + if vision_config is not None: + vision_config = PretrainedConfig(**vision_config) + self.vision_config = vision_config + self.vocab_size = self.vocab_size + + self.text_config = text_config + if text_config is not None: + self.text_config = Florence2LanguageConfig(**text_config) + + + super().__init__(**kwargs) diff --git a/colpali_engine/models/florence2/colflor/modeling_colflor.py b/colpali_engine/models/florence2/colflor/modeling_colflor.py new file mode 100644 index 00000000..ae2de839 --- /dev/null +++ b/colpali_engine/models/florence2/colflor/modeling_colflor.py @@ -0,0 +1,45 @@ +from typing import ClassVar + +import torch +from torch import nn + +from .configuration_florence2 import Florence2Config +from .modeling_florence2 import Florence2VisionLanguageModel + + +class ColFlor(Florence2VisionLanguageModel): + """ + ColFlor model implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper. + """ + + main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related + + def __init__(self, config: Florence2Config): + super().__init__(config=config) + + self.dim = 128 + self.custom_text_proj = nn.Linear(self.config.text_config.d_model, self.dim) + + self.padding_side = "right" + self.post_init() + + def forward(self, *args, **kwargs) -> torch.Tensor: + # Delete output_hidden_states from kwargs + kwargs.pop("output_hidden_states", None) + + # Create Full Attention Mask that includes both the image and text + full_attention_mask = kwargs['attention_mask'] + # make sure pixel_values are in the same dtype as the model + if 'pixel_values' in kwargs: + full_attention_mask = kwargs['full_attention_mask'].type(self.dtype) + del kwargs['full_attention_mask'] + kwargs['pixel_values'] = kwargs['pixel_values'].type(self.dtype) + + outputs = super().forward(*args, **kwargs) + + last_hidden_states = outputs['encoder_last_hidden_state'] # (batch_size, sequence_length, hidden_size) + proj = self.custom_text_proj(last_hidden_states) # (batch_size, sequence_length, dim) + # L2 normalization + proj = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim) + proj = proj * full_attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim) + return proj diff --git a/colpali_engine/models/florence2/colflor/modeling_florence2.py b/colpali_engine/models/florence2/colflor/modeling_florence2.py new file mode 100644 index 00000000..4603abf9 --- /dev/null +++ b/colpali_engine/models/florence2/colflor/modeling_florence2.py @@ -0,0 +1,3111 @@ +# ruff: noqa +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" PyTorch Florence-2 model.""" +import math +from collections import OrderedDict +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import torch.utils.checkpoint as checkpoint +from einops import rearrange +import importlib +try: + from timm.models.layers import DropPath, trunc_normal_ +except ImportError: + pass +from torch import nn +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) + +from .configuration_florence2 import Florence2Config, Florence2LanguageConfig, Florence2VisionConfig + +if is_flash_attn_2_available(): + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Florence2Config" + +class LearnedAbsolutePositionEmbedding2D(nn.Module): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, embedding_dim=256, num_pos=50): + super().__init__() + self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2) + self.column_embeddings = nn.Embedding(num_pos, embedding_dim - (embedding_dim // 2)) + + def forward(self, pixel_values): + """ + pixel_values: (batch_size, height, width, num_channels) + returns: (batch_size, height, width, embedding_dim * 2) + """ + if len(pixel_values.shape) != 4: + raise ValueError('pixel_values must be a 4D tensor') + height, width = pixel_values.shape[1:3] + width_values = torch.arange(width, device=pixel_values.device) + height_values = torch.arange(height, device=pixel_values.device) + x_emb = self.column_embeddings(width_values) + y_emb = self.row_embeddings(height_values) + # (height, width, embedding_dim * 2) + pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1) + # (embedding_dim * 2, height, width) + pos = pos.permute(2, 0, 1) + pos = pos.unsqueeze(0) + # (batch_size, embedding_dim * 2, height, width) + pos = pos.repeat(pixel_values.shape[0], 1, 1, 1) + # (batch_size, height, width, embedding_dim * 2) + pos = pos.permute(0, 2, 3, 1) + return pos + +class PositionalEmbeddingCosine1D(nn.Module): + """ + This class implements a very simple positional encoding. It follows closely + the encoder from the link below: + https://pytorch.org/tutorials/beginner/translation_transformer.html + + Args: + embed_dim: The dimension of the embeddings. + dropout_prob: The dropout probability. + max_seq_len: The maximum length to precompute the positional encodings. + """ + def __init__( + self, + embed_dim: int = 512, + max_seq_len: int = 1024) -> None: + super(PositionalEmbeddingCosine1D, self).__init__() + self.embed_dim = embed_dim + self.max_seq_len = max_seq_len + # Generate the sinusoidal arrays. + factor = math.log(10000) + denominator = torch.exp( + -factor * torch.arange(0, self.embed_dim, 2) / self.embed_dim) + # Matrix where rows correspond to a positional embedding as a function + # of the position index (i.e., the row index). + frequencies = \ + torch.arange(0, self.max_seq_len) \ + .reshape(self.max_seq_len, 1) * denominator + pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim)) + # Populate uneven entries. + pos_idx_to_embed[:, 0::2] = torch.sin(frequencies) + pos_idx_to_embed[:, 1::2] = torch.cos(frequencies) + # Save the positional embeddings in a constant buffer. + self.register_buffer("pos_idx_to_embed", pos_idx_to_embed) + + def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor: + """ + Args: + seq_embeds: The sequence embeddings in order. Allowed size: + 1. [T, D], where T is the length of the sequence, and D is the + frame embedding dimension. + 2. [B, T, D], where B is the batch size and T and D are the + same as above. + + Returns a tensor of with the same dimensions as the input: i.e., + [1, T, D] or [T, D]. + """ + shape_len = len(seq_embeds.shape) + assert 2 <= shape_len <= 3 + len_seq = seq_embeds.size(-2) + assert len_seq <= self.max_seq_len + pos_embeds = self.pos_idx_to_embed[0:seq_embeds.size(-2), :] + # Adapt pre-computed positional embeddings to the input. + if shape_len == 3: + pos_embeds = pos_embeds.view( + (1, pos_embeds.size(0), pos_embeds.size(1))) + return pos_embeds + + +class LearnedAbsolutePositionEmbedding1D(nn.Module): + """ + Learnable absolute positional embeddings for 1D sequences. + + Args: + embed_dim: The dimension of the embeddings. + max_seq_len: The maximum length to precompute the positional encodings. + """ + def __init__( + self, + embedding_dim: int = 512, + num_pos: int = 1024) -> None: + super(LearnedAbsolutePositionEmbedding1D, self).__init__() + self.embeddings = nn.Embedding(num_pos, embedding_dim) + self.num_pos = num_pos + + def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor: + """ + Args: + seq_embeds: The sequence embeddings in order. Allowed size: + 1. [T, D], where T is the length of the sequence, and D is the + frame embedding dimension. + 2. [B, T, D], where B is the batch size and T and D are the + same as above. + + Returns a tensor of with the same dimensions as the input: i.e., + [1, T, D] or [T, D]. + """ + shape_len = len(seq_embeds.shape) + assert 2 <= shape_len <= 3 + len_seq = seq_embeds.size(-2) + assert len_seq <= self.num_pos + # [T, D] + pos_embeds = self.embeddings(torch.arange(len_seq).to(seq_embeds.device)) + # Adapt pre-computed positional embeddings to the input. + if shape_len == 3: + pos_embeds = pos_embeds.view( + (1, pos_embeds.size(0), pos_embeds.size(1))) + return pos_embeds + + + +class MySequential(nn.Sequential): + def forward(self, *inputs): + for module in self._modules.values(): + if type(inputs) == tuple: + inputs = module(*inputs) + else: + inputs = module(inputs) + return inputs + + +class PreNorm(nn.Module): + def __init__(self, norm, fn, drop_path=None): + super().__init__() + self.norm = norm + self.fn = fn + self.drop_path = drop_path + + def forward(self, x, *args, **kwargs): + shortcut = x + if self.norm != None: + x, size = self.fn(self.norm(x), *args, **kwargs) + else: + x, size = self.fn(x, *args, **kwargs) + + if self.drop_path: + x = self.drop_path(x) + + x = shortcut + x + + return x, size + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.net = nn.Sequential(OrderedDict([ + ("fc1", nn.Linear(in_features, hidden_features)), + ("act", act_layer()), + ("fc2", nn.Linear(hidden_features, out_features)) + ])) + + def forward(self, x, size): + return self.net(x), size + + +class DepthWiseConv2d(nn.Module): + def __init__( + self, + dim_in, + kernel_size, + padding, + stride, + bias=True, + ): + super().__init__() + self.dw = nn.Conv2d( + dim_in, dim_in, + kernel_size=kernel_size, + padding=padding, + groups=dim_in, + stride=stride, + bias=bias + ) + + def forward(self, x, size): + B, N, C = x.shape + H, W = size + assert N == H * W + + x = self.dw(x.transpose(1, 2).view(B, C, H, W)) + size = (x.size(-2), x.size(-1)) + x = x.flatten(2).transpose(1, 2) + return x, size + + +class ConvEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__( + self, + patch_size=7, + in_chans=3, + embed_dim=64, + stride=4, + padding=2, + norm_layer=None, + pre_norm=True + ): + super().__init__() + self.patch_size = patch_size + + self.proj = nn.Conv2d( + in_chans, embed_dim, + kernel_size=patch_size, + stride=stride, + padding=padding + ) + + dim_norm = in_chans if pre_norm else embed_dim + self.norm = norm_layer(dim_norm) if norm_layer else None + + self.pre_norm = pre_norm + + def forward(self, x, size): + H, W = size + if len(x.size()) == 3: + if self.norm and self.pre_norm: + x = self.norm(x) + x = rearrange( + x, 'b (h w) c -> b c h w', + h=H, w=W + ) + + x = self.proj(x) + + _, _, H, W = x.shape + x = rearrange(x, 'b c h w -> b (h w) c') + if self.norm and not self.pre_norm: + x = self.norm(x) + + return x, (H, W) + + +class ChannelAttention(nn.Module): + + def __init__(self, dim, groups=8, qkv_bias=True): + super().__init__() + + self.groups = groups + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + def forward(self, x, size): + B, N, C = x.shape + + qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * (float(N) ** -0.5) + attention = q.transpose(-1, -2) @ k + attention = attention.softmax(dim=-1) + x = (attention @ v.transpose(-1, -2)).transpose(-1, -2) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + return x, size + + +class ChannelBlock(nn.Module): + + def __init__(self, dim, groups, mlp_ratio=4., qkv_bias=True, + drop_path_rate=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, + conv_at_attn=True, conv_at_ffn=True): + super().__init__() + + drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None + self.channel_attn = PreNorm( + norm_layer(dim), + ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias), + drop_path + ) + self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None + self.ffn = PreNorm( + norm_layer(dim), + Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer), + drop_path + ) + + def forward(self, x, size): + if self.conv1: + x, size = self.conv1(x, size) + x, size = self.channel_attn(x, size) + + if self.conv2: + x, size = self.conv2(x, size) + x, size = self.ffn(x, size) + + return x, size + + +def window_partition(x, window_size: int): + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int): + B = batch_size + # this will cause onnx conversion failed for dynamic axis, because treated as constant + # int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + def __init__(self, dim, num_heads, window_size, qkv_bias=True): + + super().__init__() + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = float(head_dim) ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, size): + + H, W = size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + x = window_partition(x, self.window_size) + x = x.view(-1, self.window_size * self.window_size, C) + + # W-MSA/SW-MSA + # attn_windows = self.attn(x_windows) + + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + attn = self.softmax(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + + # merge windows + x = x.view( + -1, self.window_size, self.window_size, C + ) + x = window_reverse(x, B, self.window_size, Hp, Wp) + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + return x, size + + +class SpatialBlock(nn.Module): + + def __init__(self, dim, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop_path_rate=0., act_layer=nn.GELU, + norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True): + super().__init__() + + drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None + self.window_attn = PreNorm( + norm_layer(dim), + WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias), + drop_path + ) + self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None + self.ffn = PreNorm( + norm_layer(dim), + Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer), + drop_path + ) + + def forward(self, x, size): + if self.conv1: + x, size = self.conv1(x, size) + x, size = self.window_attn(x, size) + + if self.conv2: + x, size = self.conv2(x, size) + x, size = self.ffn(x, size) + return x, size + + +class DaViT(nn.Module): + """ DaViT: Dual-Attention Transformer + + Args: + in_chans (int): Number of input image channels. Default: 3. + num_classes (int): Number of classes for classification head. Default: 1000. + patch_size (tuple(int)): Patch size of convolution in different stages. Default: (7, 2, 2, 2). + patch_stride (tuple(int)): Patch stride of convolution in different stages. Default: (4, 2, 2, 2). + patch_padding (tuple(int)): Patch padding of convolution in different stages. Default: (3, 0, 0, 0). + patch_prenorm (tuple(bool)): If True, perform norm before convlution layer. Default: (True, False, False, False). + embed_dims (tuple(int)): Patch embedding dimension in different stages. Default: (64, 128, 192, 256). + num_heads (tuple(int)): Number of spatial attention heads in different stages. Default: (4, 8, 12, 16). + num_groups (tuple(int)): Number of channel groups in different stages. Default: (4, 8, 12, 16). + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True. + drop_path_rate (float): Stochastic depth rate. Default: 0.1. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + enable_checkpoint (bool): If True, enable checkpointing. Default: False. + conv_at_attn (bool): If True, performe depthwise convolution before attention layer. Default: True. + conv_at_ffn (bool): If True, performe depthwise convolution before ffn layer. Default: True. + """ + + def __init__( + self, + in_chans=3, + num_classes=1000, + depths=(1, 1, 3, 1), + patch_size=(7, 2, 2, 2), + patch_stride=(4, 2, 2, 2), + patch_padding=(3, 0, 0, 0), + patch_prenorm=(False, False, False, False), + embed_dims=(64, 128, 192, 256), + num_heads=(3, 6, 12, 24), + num_groups=(3, 6, 12, 24), + window_size=7, + mlp_ratio=4., + qkv_bias=True, + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + enable_checkpoint=False, + conv_at_attn=True, + conv_at_ffn=True, + ): + super().__init__() + + self.num_classes = num_classes + self.embed_dims = embed_dims + self.num_heads = num_heads + self.num_groups = num_groups + self.num_stages = len(self.embed_dims) + self.enable_checkpoint = enable_checkpoint + assert self.num_stages == len(self.num_heads) == len(self.num_groups) + + num_stages = len(embed_dims) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)*2)] + + depth_offset = 0 + convs = [] + blocks = [] + for i in range(num_stages): + conv_embed = ConvEmbed( + patch_size=patch_size[i], + stride=patch_stride[i], + padding=patch_padding[i], + in_chans=in_chans if i == 0 else self.embed_dims[i - 1], + embed_dim=self.embed_dims[i], + norm_layer=norm_layer, + pre_norm=patch_prenorm[i] + ) + convs.append(conv_embed) + + block = MySequential( + *[ + MySequential(OrderedDict([ + ( + 'spatial_block', SpatialBlock( + embed_dims[i], + num_heads[i], + window_size, + drop_path_rate=dpr[depth_offset+j*2], + qkv_bias=qkv_bias, + mlp_ratio=mlp_ratio, + conv_at_attn=conv_at_attn, + conv_at_ffn=conv_at_ffn, + ) + ), + ( + 'channel_block', ChannelBlock( + embed_dims[i], + num_groups[i], + drop_path_rate=dpr[depth_offset+j*2+1], + qkv_bias=qkv_bias, + mlp_ratio=mlp_ratio, + conv_at_attn=conv_at_attn, + conv_at_ffn=conv_at_ffn, + ) + ) + ])) for j in range(depths[i]) + ] + ) + blocks.append(block) + depth_offset += depths[i]*2 + + self.convs = nn.ModuleList(convs) + self.blocks = nn.ModuleList(blocks) + + self.norms = norm_layer(self.embed_dims[-1]) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + @property + def dim_out(self): + return self.embed_dims[-1] + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, std=0.02) + for name, _ in m.named_parameters(): + if name in ['bias']: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.weight, 1.0) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1.0) + nn.init.constant_(m.bias, 0) + + def forward_features_unpool(self, x): + """ + forward until avg pooling + Args: + x (_type_): input image tensor + """ + input_size = (x.size(2), x.size(3)) + for conv, block in zip(self.convs, self.blocks): + x, input_size = conv(x, input_size) + if self.enable_checkpoint: + x, input_size = checkpoint.checkpoint(block, x, input_size) + else: + x, input_size = block(x, input_size) + return x + + def forward_features(self, x): + x = self.forward_features_unpool(x) + + # (batch_size, num_tokens, token_dim) + x = self.avgpool(x.transpose(1, 2)) + # (batch_size, 1, num_tokens) + x = torch.flatten(x, 1) + x = self.norms(x) + + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + @classmethod + def from_config(cls, config): + return cls( + depths=config.depths, + embed_dims=config.dim_embed, + num_heads=config.num_heads, + num_groups=config.num_groups, + patch_size=config.patch_size, + patch_stride=config.patch_stride, + patch_padding=config.patch_padding, + patch_prenorm=config.patch_prenorm, + drop_path_rate=config.drop_path_rate, + window_size=config.window_size, + ) + + + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class Florence2LearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # Florence2 is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): + """`input_ids' shape is expected to be [bsz x seqlen].""" + + bsz, seq_len = input_ids.shape[:2] + positions = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ).expand(bsz, -1) + + return super().forward(positions + self.offset) + + +class Florence2ScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + +class Florence2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[Florence2LanguageConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class Florence2FlashAttention2(Florence2Attention): + """ + Florence2 flash attention module. This module inherits from `Florence2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # Florence2FlashAttention2 attention does not support output_attentions + if output_attentions: + raise ValueError("Florence2FlashAttention2 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class Florence2SdpaAttention(Florence2Attention): + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + if output_attentions or layer_head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Florence2Model is using Florence2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" + ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + key_value_states=key_value_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + query_states = self._shape(query_states, tgt_len, bsz) + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, + # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + +FLORENCE2_ATTENTION_CLASSES = { + "eager": Florence2Attention, + "sdpa": Florence2SdpaAttention, + "flash_attention_2": Florence2FlashAttention2, +} + + +class Florence2EncoderLayer(nn.Module): + def __init__(self, config: Florence2LanguageConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = FLORENCE2_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + layer_head_mask: torch.FloatTensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class Florence2DecoderLayer(nn.Module): + def __init__(self, config: Florence2LanguageConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = FLORENCE2_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + is_causal=True, + config=config, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = FLORENCE2_ATTENTION_CLASSES[config._attn_implementation]( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + + +class Florence2LanguagePreTrainedModel(PreTrainedModel): + config_class = Florence2LanguageConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"] + _no_split_modules = [r"Florence2EncoderLayer", r"Florence2DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + dummy_inputs = { + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + } + return dummy_inputs + + +class Florence2Encoder(Florence2LanguagePreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`Florence2EncoderLayer`]. + + Args: + config: Florence2LanguageConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: Florence2LanguageConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.embed_tokens = Florence2ScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = Florence2LearnedPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + ) + self.layers = nn.ModuleList([Florence2EncoderLayer(config) for _ in range(config.encoder_layers)]) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_sdpa = config._attn_implementation == "sdpa" + self.layernorm_embedding = nn.LayerNorm(embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_ids = input_ids.view(-1, input_ids.shape[-1]) + elif inputs_embeds is not None: + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + embed_pos = self.embed_positions(input) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + if self._use_flash_attention_2: + attention_mask = attention_mask if 0 in attention_mask else None + elif self._use_sdpa and head_mask is None and not output_attentions: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class Florence2Decoder(Florence2LanguagePreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`Florence2DecoderLayer`] + + Args: + config: Florence2LanguageConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: Florence2LanguageConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = Florence2ScaledWordEmbedding( + config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale + ) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = Florence2LearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + ) + self.layers = nn.ModuleList([Florence2DecoderLayer(config) for _ in range(config.decoder_layers)]) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_sdpa = config._attn_implementation == "sdpa" + + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self._use_flash_attention_2: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # embed positions + positions = self.embed_positions(input, past_key_values_length) + positions = positions.to(inputs_embeds.device) + + hidden_states = inputs_embeds + positions + hidden_states = self.layernorm_embedding(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + None, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +class Florence2LanguageModel(Florence2LanguagePreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: Florence2LanguageConfig): + super().__init__(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + + self.encoder = Florence2Encoder(config, self.shared) + self.decoder = Florence2Decoder(config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqModelOutput]: + # different to other models, Florence2 automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." + ) + + decoder_input_ids = shift_tokens_right( + input_ids, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) +# decoder_outputs = self.decoder( +# input_ids=decoder_input_ids, +# attention_mask=decoder_attention_mask, +# encoder_hidden_states=encoder_outputs[0], +# encoder_attention_mask=attention_mask, +# head_mask=decoder_head_mask, +# cross_attn_head_mask=cross_attn_head_mask, +# past_key_values=past_key_values, +# inputs_embeds=decoder_inputs_embeds, +# use_cache=use_cache, +# output_attentions=output_attentions, +# output_hidden_states=output_hidden_states, +# return_dict=return_dict, +# ) + + if not return_dict: + return encoder_outputs #decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + #last_hidden_state=decoder_outputs.last_hidden_state, + #past_key_values=decoder_outputs.past_key_values, + #decoder_hidden_states=decoder_outputs.hidden_states, + #decoder_attentions=decoder_outputs.attentions, + #cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel): + base_model_prefix = "model" + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _keys_to_ignore_on_load_missing = ["final_logits_bias"] + + def __init__(self, config: Florence2LanguageConfig): + super().__init__(config) + self.model = Florence2LanguageModel(config) + self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) + return new_embeddings + + def _resize_final_logits_bias(self, new_num_tokens: int) -> None: + old_num_tokens = self.final_logits_bias.shape[-1] + if new_num_tokens <= old_num_tokens: + new_bias = self.final_logits_bias[:, :new_num_tokens] + else: + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) + self.register_buffer("final_logits_bias", new_bias) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + #lm_logits = self.lm_head(outputs[0]) + #lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device) + + #masked_lm_loss = None + #if labels is not None: + # labels = labels.to(lm_logits.device) + # loss_fct = CrossEntropyLoss() + # masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + #if not return_dict: + # output = (lm_logits,) + outputs[1:] + # return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + #loss=masked_lm_loss, + #logits=lm_logits, + #past_key_values=outputs.past_key_values, + #decoder_hidden_states=outputs.decoder_hidden_states, + #decoder_attentions=outputs.decoder_attentions, + #cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + attention_mask=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past_key_values is used + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) + + layer_past[2:], + ) + return reordered_past + +@dataclass +class Florence2Seq2SeqLMOutput(ModelOutput): + """ + Base class for Florence-2 model's outputs that also contains : pre-computed hidden states that can speed up sequential + decoding. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, + num_image_tokens, hidden_size)`. + + image_hidden_states of the model produced by the vision encoder + """ + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + image_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Florence2VisionLMOutput(ModelOutput): + """ + Base class for Florence-2 model's outputs that also contains : pre-computed hidden states that can speed up sequential + decoding. + Args: + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, + num_image_tokens, hidden_size)`. + image_hidden_states of the model produced by the vision encoder + """ + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + image_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + +FLORENCE2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Florence2Config`] or [`Florence2VisionConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Florence-2 Model outputting raw hidden-states without any specific head on top.", + FLORENCE2_START_DOCSTRING, +) +class Florence2PreTrainedModel(PreTrainedModel): + config_class = Florence2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _skip_keys_device_placement = "past_key_values" + + @property + def _supports_flash_attn_2(self): + """ + Retrieve language_model's attribute to check whether the model supports + Flash Attention 2 or not. + """ + return self.language_model._supports_flash_attn_2 + + @property + def _supports_sdpa(self): + """ + Retrieve language_model's attribute to check whether the model supports + SDPA or not. + """ + return self.language_model._supports_sdpa + + +FLORENCE2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`Florence2Processor`] uses + [`CLIPImageProcessor`] for processing images). + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +@add_start_docstrings( + """The FLORENCE2 vision model without any head""", + FLORENCE2_START_DOCSTRING, +) +class Florence2VisionModel(Florence2PreTrainedModel): + def __init__(self, config: Florence2VisionConfig): + super().__init__(config) + assert config.model_type == 'davit', 'only DaViT is supported for now' + spec = importlib.util.find_spec('timm') + if spec is None: + raise ImportError('timm is required to use DaViT model') + + self.vision_tower = DaViT.from_config(config=config) + + self.post_init() + + def forward(self, pixel_values): + if len(pixel_values.shape) == 4: + x = self.vision_tower.forward_features_unpool(pixel_values) + else: + raise ValueError(f'invalid image shape {pixel_values.shape}') + return x + + +@add_start_docstrings( + """The FLORENCE2 vision model with projection layer""", + FLORENCE2_START_DOCSTRING, +) +class Florence2VisionModelWithProjection(Florence2PreTrainedModel): + def __init__(self, config: Florence2VisionConfig): + super().__init__(config) + assert config.model_type == 'davit', 'only DaViT is supported for now' + self.vision_tower = DaViT.from_config(config=config) + + self._build_image_projection_layers(config) + + self.post_init() + + def _build_image_projection_layers(self, config): + image_dim_out = config.dim_embed[-1] + dim_projection = config.projection_dim + self.image_projection = nn.Parameter( + torch.empty(image_dim_out, dim_projection) + ) + self.image_proj_norm = nn.LayerNorm(dim_projection) + image_pos_embed_config = config.image_pos_embed + if image_pos_embed_config['type'] == 'learned_abs_2d': + self.image_pos_embed = LearnedAbsolutePositionEmbedding2D( + embedding_dim=image_dim_out, + num_pos=image_pos_embed_config['max_pos_embeddings'] + ) + else: + raise NotImplementedError('Not implemented yet') + + self.image_feature_source = config.image_feature_source + + # temporal embedding + visual_temporal_embedding_config = config.visual_temporal_embedding + if visual_temporal_embedding_config['type'] == 'COSINE': + self.visual_temporal_embed = PositionalEmbeddingCosine1D( + embed_dim=image_dim_out, + max_seq_len=visual_temporal_embedding_config['max_temporal_embeddings'] + ) + else: + raise NotImplementedError('Not implemented yet') + + def forward(self, pixel_values): + if len(pixel_values.shape) == 4: + batch_size, C, H, W = pixel_values.shape + T = 1 + x = self.vision_tower.forward_features_unpool(pixel_values) + else: + raise ValueError(f'invalid image shape {pixel_values.shape}') + + if self.image_pos_embed is not None: + x = x.view(batch_size * T, -1, x.shape[-1]) + num_tokens = x.shape[-2] + h, w = int(num_tokens ** 0.5), int(num_tokens ** 0.5) + assert h * w == num_tokens, 'only support square feature maps for now' + x = x.view(batch_size * T, h, w, x.shape[-1]) + pos_embed = self.image_pos_embed(x) + x = x + pos_embed + x = x.view(batch_size, T * h*w, x.shape[-1]) + + if self.visual_temporal_embed is not None: + visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]) + x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1]) + + x_feat_dict = {} + + spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2) + x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x + + temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1) + x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x + + x = x.view(batch_size, T, -1, x.shape[-1])[:, -1] + x_feat_dict['last_frame'] = x + + new_x = [] + for _image_feature_source in self.image_feature_source: + if _image_feature_source not in x_feat_dict: + raise ValueError('invalid image feature source: {}'.format(_image_feature_source)) + new_x.append(x_feat_dict[_image_feature_source]) + + x = torch.cat(new_x, dim=1) + + x = x @ self.image_projection + x = self.image_proj_norm(x) + + + return x + + + +@add_start_docstrings( + """The FLORENCE2 model which consists of a vision backbone and a language model.""", + FLORENCE2_START_DOCSTRING, +) +class Florence2ForConditionalGeneration(Florence2PreTrainedModel): + def __init__(self, config: Florence2Config): + super().__init__(config) + assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now' + self.vision_tower = DaViT.from_config(config=config.vision_config) + # remove unused layers + del self.vision_tower.head + del self.vision_tower.norms + + self.vocab_size = config.vocab_size + self._attn_implementation = config._attn_implementation + self._build_image_projection_layers(config) + + language_model = Florence2LanguageForConditionalGeneration(config=config.text_config) + + if language_model._tied_weights_keys is not None: + self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] + self.language_model = language_model + + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.post_init() + + def _build_image_projection_layers(self, config): + image_dim_out = config.vision_config.dim_embed[-1] + dim_projection = config.vision_config.projection_dim + self.image_projection = nn.Parameter( + torch.empty(image_dim_out, dim_projection) + ) + self.image_proj_norm = nn.LayerNorm(dim_projection) + image_pos_embed_config = config.vision_config.image_pos_embed + if image_pos_embed_config['type'] == 'learned_abs_2d': + self.image_pos_embed = LearnedAbsolutePositionEmbedding2D( + embedding_dim=image_dim_out, + num_pos=image_pos_embed_config['max_pos_embeddings'] + ) + else: + raise NotImplementedError('Not implemented yet') + + self.image_feature_source = config.vision_config.image_feature_source + + # temporal embedding + visual_temporal_embedding_config = config.vision_config.visual_temporal_embedding + if visual_temporal_embedding_config['type'] == 'COSINE': + self.visual_temporal_embed = PositionalEmbeddingCosine1D( + embed_dim=image_dim_out, + max_seq_len=visual_temporal_embedding_config['max_temporal_embeddings'] + ) + else: + raise NotImplementedError('Not implemented yet') + + def get_encoder(self): + return self.language_model.get_encoder() + + def get_decoder(self): + return self.language_model.get_decoder() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding: + model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + # update vocab size + self.config.text_config.vocab_size = model_embeds.num_embeddings + self.config.vocab_size = model_embeds.num_embeddings + self.vocab_size = model_embeds.num_embeddings + return model_embeds + + def _encode_image(self, pixel_values): + if len(pixel_values.shape) == 4: + batch_size, C, H, W = pixel_values.shape + T = 1 + x = self.vision_tower.forward_features_unpool(pixel_values) + else: + raise ValueError(f'invalid image shape {pixel_values.shape}') + + if self.image_pos_embed is not None: + x = x.view(batch_size * T, -1, x.shape[-1]) + num_tokens = x.shape[-2] + h, w = int(num_tokens ** 0.5), int(num_tokens ** 0.5) + assert h * w == num_tokens, 'only support square feature maps for now' + x = x.view(batch_size * T, h, w, x.shape[-1]) + pos_embed = self.image_pos_embed(x) + x = x + pos_embed + x = x.view(batch_size, T * h*w, x.shape[-1]) + + if self.visual_temporal_embed is not None: + visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]) + x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1]) + + x_feat_dict = {} + + spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2) + x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x + + temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1) + x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x + + x = x.view(batch_size, T, -1, x.shape[-1])[:, -1] + x_feat_dict['last_frame'] = x + + new_x = [] + for _image_feature_source in self.image_feature_source: + if _image_feature_source not in x_feat_dict: + raise ValueError('invalid image feature source: {}'.format(_image_feature_source)) + new_x.append(x_feat_dict[_image_feature_source]) + + x = torch.cat(new_x, dim=1) + + x = x @ self.image_projection + x = self.image_proj_norm(x) + + return x + + def _merge_input_ids_with_image_features( + self, image_features, inputs_embeds + ): + batch_size, image_token_length = image_features.size()[:-1] + device = image_features.device + image_attention_mask = torch.ones(batch_size, image_token_length, device=device) + + # task_prefix_embeds: [batch_size, padded_context_length, hidden_size] + # task_prefix_attention_mask: [batch_size, context_length] + if inputs_embeds is None: + return image_features, image_attention_mask + + task_prefix_embeds = inputs_embeds + task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device) + + if len(task_prefix_attention_mask.shape) == 3: + task_prefix_attention_mask = task_prefix_attention_mask[:, 0] + + # concat [image embeds, task prefix embeds] + inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1) + attention_mask = torch.cat([image_attention_mask, task_prefix_attention_mask], dim=1) + + return inputs_embeds, attention_mask + + + @add_start_docstrings_to_model_forward(FLORENCE2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Florence2Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Florence2Seq2SeqLMOutput]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Florence2ForConditionalGeneration + + >>> model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-large") + >>> processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large") + + >>> prompt = "" + >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(text=prompt, images=image, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_length=100) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "A green car parked in front of a yellow building." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + image_features = None + if inputs_embeds is None: + # 1. Extra the input embeddings + if input_ids is not None: + inputs_embeds = self.get_input_embeddings()(input_ids) + # 2. Merge text and images + if pixel_values is not None: + # (batch_size, num_image_tokens, hidden_size) + image_features = self._encode_image(pixel_values) + inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds) + + if inputs_embeds is not None: + attention_mask = attention_mask.to(inputs_embeds.dtype) + + outputs = self.language_model( + attention_mask=attention_mask, + labels=labels, + inputs_embeds=inputs_embeds, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + #logits = outputs.logits + #logits = logits.float() + #loss = outputs.loss + #if not return_dict: + # output = (logits,) + outputs[1:] + # return (loss,) + output if loss is not None else output + + return Florence2Seq2SeqLMOutput( + #loss=loss, + #logits=logits, + #past_key_values=outputs.past_key_values, + #decoder_hidden_states=outputs.decoder_hidden_states, + #decoder_attentions=outputs.decoder_attentions, + #cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + image_hidden_states=image_features + ) + + def generate( + self, + input_ids, + inputs_embeds=None, + pixel_values=None, + **kwargs + ): + + if inputs_embeds is None: + # 1. Extra the input embeddings + if input_ids is not None: + inputs_embeds = self.get_input_embeddings()(input_ids) + # 2. Merge text and images + if pixel_values is not None: + image_features = self._encode_image(pixel_values) + inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds) + + return self.language_model.generate( + input_ids=None, + inputs_embeds=inputs_embeds, + **kwargs + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + attention_mask=None, + pixel_values=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past_key_values is used + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self.language_model.shift_tokens_right(labels) + + def _reorder_cache(self, *args, **kwargs): + return self.language_model._reorder_cache(*args, **kwargs) + + +@add_start_docstrings( + """The FLORENCE2 model which consists of a vision backbone and a language model (encoder-only).""", + FLORENCE2_START_DOCSTRING, +) +class Florence2VisionLanguageModel(Florence2PreTrainedModel): + def __init__(self, config: Florence2Config): + super().__init__(config) + + assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now' + self.vision_tower = DaViT.from_config(config=config.vision_config) + # remove unused layers + del self.vision_tower.head + del self.vision_tower.norms + + self.vocab_size = config.vocab_size + self._attn_implementation = config._attn_implementation + self._build_image_projection_layers(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + self.language_model = Florence2Encoder(config.text_config) + + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.post_init() + + def _build_image_projection_layers(self, config): + image_dim_out = config.vision_config.dim_embed[-1] + dim_projection = config.vision_config.projection_dim + self.image_projection = nn.Parameter( + torch.empty(image_dim_out, dim_projection) + ) + self.image_proj_norm = nn.LayerNorm(dim_projection) + image_pos_embed_config = config.vision_config.image_pos_embed + if image_pos_embed_config['type'] == 'learned_abs_2d': + self.image_pos_embed = LearnedAbsolutePositionEmbedding2D( + embedding_dim=image_dim_out, + num_pos=image_pos_embed_config['max_pos_embeddings'] + ) + else: + raise NotImplementedError('Not implemented yet') + + self.image_feature_source = config.vision_config.image_feature_source + + # temporal embedding + visual_temporal_embedding_config = config.vision_config.visual_temporal_embedding + if visual_temporal_embedding_config['type'] == 'COSINE': + self.visual_temporal_embed = PositionalEmbeddingCosine1D( + embed_dim=image_dim_out, + max_seq_len=visual_temporal_embedding_config['max_temporal_embeddings'] + ) + else: + raise NotImplementedError('Not implemented yet') + + def get_encoder(self): + return self.language_model + + def get_input_embeddings(self): + return self.language_model.embed_tokens + + def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding: + model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + # update vocab size + self.config.text_config.vocab_size = model_embeds.num_embeddings + self.config.vocab_size = model_embeds.num_embeddings + self.vocab_size = model_embeds.num_embeddings + return model_embeds + + def _encode_image(self, pixel_values): + if len(pixel_values.shape) == 4: + batch_size, C, H, W = pixel_values.shape + T = 1 + x = self.vision_tower.forward_features_unpool(pixel_values) + else: + raise ValueError(f'invalid image shape {pixel_values.shape}') + + if self.image_pos_embed is not None: + x = x.view(batch_size * T, -1, x.shape[-1]) + num_tokens = x.shape[-2] + h, w = int(num_tokens ** 0.5), int(num_tokens ** 0.5) + assert h * w == num_tokens, 'only support square feature maps for now' + x = x.view(batch_size * T, h, w, x.shape[-1]) + pos_embed = self.image_pos_embed(x) + x = x + pos_embed + x = x.view(batch_size, T * h*w, x.shape[-1]) + + if self.visual_temporal_embed is not None: + visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]) + x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1]) + + x_feat_dict = {} + + spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2) + x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x + + temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1) + x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x + + x = x.view(batch_size, T, -1, x.shape[-1])[:, -1] + x_feat_dict['last_frame'] = x + + new_x = [] + for _image_feature_source in self.image_feature_source: + if _image_feature_source not in x_feat_dict: + raise ValueError('invalid image feature source: {}'.format(_image_feature_source)) + new_x.append(x_feat_dict[_image_feature_source]) + + x = torch.cat(new_x, dim=1) + + x = x @ self.image_projection + x = self.image_proj_norm(x) + + return x + + def _merge_input_ids_with_image_features( + self, image_features, inputs_embeds + ): + batch_size, image_token_length = image_features.size()[:-1] + device = image_features.device + image_attention_mask = torch.ones(batch_size, image_token_length, device=device) + + # task_prefix_embeds: [batch_size, padded_context_length, hidden_size] + # task_prefix_attention_mask: [batch_size, context_length] + if inputs_embeds is None: + return image_features, image_attention_mask + + task_prefix_embeds = inputs_embeds + task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device) + + if len(task_prefix_attention_mask.shape) == 3: + task_prefix_attention_mask = task_prefix_attention_mask[:, 0] + + # concat [image embeds, task prefix embeds] + inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1) + attention_mask = torch.cat([image_attention_mask, task_prefix_attention_mask], dim=1) + + return inputs_embeds, attention_mask + + + @add_start_docstrings_to_model_forward(FLORENCE2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Florence2VisionLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + #decoder_input_ids: Optional[torch.LongTensor] = None, + #decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + #decoder_head_mask: Optional[torch.Tensor] = None, + #cross_attn_head_mask: Optional[torch.Tensor] = None, + #encoder_outputs: Optional[List[torch.FloatTensor]] = None, + #past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + #decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + #use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Florence2VisionLMOutput]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + Returns: + Example: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Florence2ForConditionalGeneration + >>> model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-large") + >>> processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large") + >>> prompt = "" + >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = processor(text=prompt, images=image, return_tensors="pt") + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_length=100) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "A green car parked in front of a yellow building." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + image_features = None + if inputs_embeds is None: + # 1. Extra the input embeddings + if input_ids is not None: + inputs_embeds = self.get_input_embeddings()(input_ids) + # 2. Merge text and images + if pixel_values is not None: + # (batch_size, num_image_tokens, hidden_size) + image_features = self._encode_image(pixel_values) + inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds) + + if inputs_embeds is not None: + attention_mask = attention_mask.to(inputs_embeds.dtype) + outputs = self.language_model( + #input_ids=input_ids, + attention_mask=attention_mask, + #labels=labels, + inputs_embeds=inputs_embeds, + #decoder_input_ids=decoder_input_ids, + #encoder_outputs=encoder_outputs, + #decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + #decoder_head_mask=decoder_head_mask, + #cross_attn_head_mask=cross_attn_head_mask, + #past_key_values=past_key_values, + #decoder_inputs_embeds=decoder_inputs_embeds, + #use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return outputs.last_hidden_state + + return Florence2VisionLMOutput( + encoder_last_hidden_state=outputs.last_hidden_state, + encoder_hidden_states=outputs.hidden_states, + encoder_attentions=outputs.attentions, + image_hidden_states=image_features + ) + + #def _reorder_cache(self, *args, **kwargs): + # return self.language_model._reorder_cache(*args, **kwargs) diff --git a/colpali_engine/models/florence2/colflor/processing_colflor.py b/colpali_engine/models/florence2/colflor/processing_colflor.py new file mode 100644 index 00000000..6563813a --- /dev/null +++ b/colpali_engine/models/florence2/colflor/processing_colflor.py @@ -0,0 +1,89 @@ +from typing import List, Optional, Tuple, Union + +import torch +from PIL import Image +from transformers import BatchFeature + +from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor + +from .processing_florence2 import Florence2Processor + + +class ColFlorProcessor(BaseVisualRetrieverProcessor, Florence2Processor): + """ + Processor for ColPali. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def process_images( + self, + images: List[Image.Image], + ) -> BatchFeature: + """ + Process images for ColFlor2. + """ + texts_doc = [""] * len(images) + images = [image.convert("RGB") for image in images] + + batch_doc = self( + text=texts_doc, + images=images, + return_tensors="pt", + padding="longest", + ) + + new_part = torch.ones((batch_doc['attention_mask'].size()[0], 577)).to(batch_doc['attention_mask'].device) + batch_doc['full_attention_mask'] = torch.cat([new_part, batch_doc['attention_mask']], dim=1) + + return batch_doc + + def process_queries( + self, + queries: List[str], + max_length: int = 50, + suffix: Optional[str] = None, + ) -> BatchFeature: + """ + Process queries for ColFlor2. + """ + if suffix is None: + suffix = "" * 10 + texts_query: List[str] = [] + + for query in queries: + query = f"Query: {query}" + query += suffix # add suffix (pad tokens) + texts_query.append(query) + + batch_query = self.tokenizer( + text=texts_query, + return_tensors="pt", + padding="longest", + max_length= max_length + self.image_seq_length, + ) + + return batch_query + + def get_n_patches( + self, + image_size: Tuple[int, int], + patch_size: int, + ) -> Tuple[int, int]: + n_patches_x = self.image_processor.size["width"] // patch_size + n_patches_y = self.image_processor.size["height"] // patch_size + + return n_patches_x, n_patches_y + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> torch.Tensor: + """ + Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings. + """ + return self.score_multi_vector(qs, ps, device=device, **kwargs) diff --git a/colpali_engine/models/florence2/colflor/processing_florence2.py b/colpali_engine/models/florence2/colflor/processing_florence2.py new file mode 100644 index 00000000..56a93f4b --- /dev/null +++ b/colpali_engine/models/florence2/colflor/processing_florence2.py @@ -0,0 +1,1087 @@ +# ruff: noqa +# coding=utf-8 +# Copyright 2024 Microsoft and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Florence-2. +""" + +import logging +import re +from typing import List, Optional, Union + +import numpy as np +import torch +from transformers.feature_extraction_utils import BatchFeature +from transformers.image_utils import ImageInput, is_valid_image +from transformers.processing_utils import ProcessorMixin +from transformers.tokenization_utils_base import ( + PaddingStrategy, + PreTokenizedInput, + TextInput, + TruncationStrategy, +) +from transformers.utils import TensorType + +logger = logging.getLogger(__name__) + +# Copied from transformers.models.idefics2.processing_idefics2.is_url +def is_url(val) -> bool: + return isinstance(val, str) and val.startswith("http") + +# Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url +def is_image_or_image_url(elem): + return is_url(elem) or is_valid_image(elem) + + +def _is_str_or_image(elem): + return isinstance(elem, (str)) or is_image_or_image_url(elem) + + +class Florence2Processor(ProcessorMixin): + r""" + Constructs a Florence2 processor which wraps a Florence2 image processor and a Florence2 tokenizer into a single processor. + + [`Florence2Processor`] offers all the functionalities of [`CLIPImageProcessor`] and [`BartTokenizerFast`]. See the + [`~Florence2Processor.__call__`] and [`~Florence2Processor.decode`] for more information. + + Args: + image_processor ([`CLIPImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`BartTokenizerFast`], *optional*): + The tokenizer is a required input. + """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "CLIPImageProcessor" + tokenizer_class = ("BartTokenizer", "BartTokenizerFast") + + def __init__( + self, + image_processor=None, + tokenizer=None, + ): + if image_processor is None: + raise ValueError("You need to specify an `image_processor`.") + if tokenizer is None: + raise ValueError("You need to specify a `tokenizer`.") + if not hasattr(image_processor, "image_seq_length"): + raise ValueError("Image processor is missing an `image_seq_length` attribute.") + + self.image_seq_length = image_processor.image_seq_length + + tokens_to_add = { + 'additional_special_tokens': \ + tokenizer.additional_special_tokens + \ + ['', '', '', ''] + \ + [f'' for x in range(1000)] + \ + ['', '', '', '','', '', '', '', '', '', '', '', '', '', '', '', '', '', '', ''] + } + tokenizer.add_special_tokens(tokens_to_add) + + self.tasks_answer_post_processing_type = { + '': 'pure_text', + '': 'ocr', + '': 'pure_text', + '': 'pure_text', + '': 'pure_text', + '': 'description_with_bboxes', + '': 'description_with_bboxes', + '': "phrase_grounding", + '': 'polygons', + '': 'polygons', + '': 'description_with_bboxes_or_polygons', + '': 'pure_text', + '': 'pure_text', + '': 'pure_text', + '': 'bboxes' + } + + self.task_prompts_without_inputs = { + '': 'What is the text in the image?', + '': 'What is the text in the image, with regions?', + '': 'What does the image describe?', + '': 'Describe in detail what is shown in the image.', + '': 'Describe with a paragraph what is shown in the image.', + '': 'Locate the objects with category name in the image.', + '': 'Locate the objects in the image, with their descriptions.', + '': 'Locate the region proposals in the image.' + } + + self.task_prompts_with_input = { + '': "Locate the phrases in the caption: {input}", + '': 'Locate {input} in the image with mask', + '': 'What is the polygon mask of region {input}', + '': 'Locate {input} in the image.', + '': 'What is the region {input}?', + '': 'What does the region {input} describe?', + '': 'What text is in the region {input}?', + } + + self.post_processor = Florence2PostProcesser(tokenizer=tokenizer) + + + super().__init__(image_processor, tokenizer) + + def _construct_prompts(self, text): + # replace the task tokens with the task prompts if task token is in the text + prompts = [] + for _text in text: + # 1. fixed task prompts without additional inputs + for task_token, task_prompt in self.task_prompts_without_inputs.items(): + if task_token in _text: + assert _text == task_token, f"Task token {task_token} should be the only token in the text." + _text = task_prompt + break + # 2. task prompts with additional inputs + for task_token, task_prompt in self.task_prompts_with_input.items(): + if task_token in _text: + _text = task_prompt.format(input=_text.replace(task_token, '')) + break + prompts.append(_text) + return prompts + + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + images: ImageInput = None, + tokenize_newline_separately: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length=None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + do_resize: bool = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821 + input_data_format: Optional[ + Union[str, "ChannelDimension"] # noqa: F821 + ] = None, + resample: "PILImageResampling" = None, # noqa: F821 + do_convert_rgb: bool = None, + do_thumbnail: bool = None, + do_align_long_axis: bool = None, + do_rescale: bool = None, + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to BartTokenizerFast's [`~BartTokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to + CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring + of the above two methods for more information. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + tokenize_newline_separately (`bool`, defaults to `True`): + Adds a separately tokenized '\n' at the end of the prompt. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`, *optional*): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. If `suffix` + is provided, the `input_ids` will also contain the suffix input ids. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **labels** -- Labels compatible with training if `suffix` is not None + """ + + return_token_type_ids = False + + if images is None: + raise ValueError("`images` are expected as arguments to a `Florence2Processor` instance.") + if text is None: + logger.warning_once( + "You are using Florence-2 without a text prompt." + ) + text = "" + + if isinstance(text, List) and isinstance(images, List): + if len(images) < len(text): + raise ValueError( + f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image." + ) + if _is_str_or_image(text): + text = [text] + elif isinstance(text, list) and _is_str_or_image(text[0]): + pass + + pixel_values = self.image_processor( + images, + do_resize=do_resize, + do_normalize=do_normalize, + return_tensors=return_tensors, + image_mean=image_mean, + image_std=image_std, + input_data_format=input_data_format, + data_format=data_format, + resample=resample, + do_convert_rgb=do_convert_rgb, + )["pixel_values"] + + if max_length is not None: + max_length -= self.image_seq_length # max_length has to account for the image tokens + + text = self._construct_prompts(text) + + inputs = self.tokenizer( + text, + return_tensors=return_tensors, + padding=padding, + max_length=max_length, + truncation=truncation, + return_token_type_ids=return_token_type_ids, + ) + + return_data = {**inputs, "pixel_values": pixel_values} + + if return_token_type_ids: + labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100) + return_data.update({"labels": labels}) + return BatchFeature(data=return_data) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Florence2 + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Florence2 + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Florence2 + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + def post_process_generation(self, text, task, image_size): + """ + Post-process the output of the model to each of the task outputs. + + Args: + text (`str`): The text to post-process. + task (`str`): The task to post-process the text for. + image_size (`Tuple[int, int]`): The size of the image. height x width. + """ + + task_answer_post_processing_type = self.tasks_answer_post_processing_type.get(task, 'pure_text') + task_answer = self.post_processor( + text=text, + image_size=image_size, + parse_tasks=task_answer_post_processing_type, + )[task_answer_post_processing_type] + + if task_answer_post_processing_type == 'pure_text': + final_answer = task_answer + # remove the special tokens + final_answer = final_answer.replace('', '').replace('', '') + elif task_answer_post_processing_type in ['od', 'description_with_bboxes', 'bboxes']: + od_instances = task_answer + bboxes_od = [_od_instance['bbox'] for _od_instance in od_instances] + labels_od = [str(_od_instance['cat_name']) for _od_instance in od_instances] + final_answer = {'bboxes': bboxes_od, 'labels': labels_od} + elif task_answer_post_processing_type in ['ocr']: + bboxes = [_od_instance['quad_box'] for _od_instance in task_answer] + labels = [str(_od_instance['text']) for _od_instance in task_answer] + final_answer = {'quad_boxes': bboxes, 'labels': labels} + elif task_answer_post_processing_type in ['phrase_grounding']: + bboxes = [] + labels = [] + for _grounded_phrase in task_answer: + for _bbox in _grounded_phrase['bbox']: + bboxes.append(_bbox) + labels.append(_grounded_phrase['cat_name']) + final_answer = {'bboxes': bboxes, 'labels': labels} + elif task_answer_post_processing_type in ['description_with_polygons', 'polygons']: + labels = [] + polygons = [] + for result in task_answer: + label = result['cat_name'] + _polygons = result['polygons'] + labels.append(label) + polygons.append(_polygons) + final_answer = {'polygons': polygons, 'labels': labels} + elif task_answer_post_processing_type in ['description_with_bboxes_or_polygons']: + bboxes = [] + bboxes_labels = [] + polygons = [] + polygons_labels = [] + for result in task_answer: + label = result['cat_name'] + if 'polygons' in result: + _polygons = result['polygons'] + polygons.append(_polygons) + polygons_labels.append(label) + else: + _bbox = result['bbox'] + bboxes.append(_bbox) + bboxes_labels.append(label) + final_answer = {'bboxes': bboxes, 'bboxes_labels': bboxes_labels, 'polygons': polygons, 'polygons_labels': polygons_labels} + else: + raise ValueError('Unknown task answer post processing type: {}'.format(task_answer_post_processing_type)) + + final_answer = { + task: final_answer} + return final_answer + +class BoxQuantizer(object): + def __init__(self, mode, bins): + self.mode = mode + self.bins = bins + + def quantize(self, boxes: torch.Tensor, size): + bins_w, bins_h = self.bins # Quantization bins. + size_w, size_h = size # Original image size. + size_per_bin_w = size_w / bins_w + size_per_bin_h = size_h / bins_h + xmin, ymin, xmax, ymax = boxes.split(1, dim=-1) # Shape: 4 * [N, 1]. + + if self.mode == 'floor': + quantized_xmin = ( + xmin / size_per_bin_w).floor().clamp(0, bins_w - 1) + quantized_ymin = ( + ymin / size_per_bin_h).floor().clamp(0, bins_h - 1) + quantized_xmax = ( + xmax / size_per_bin_w).floor().clamp(0, bins_w - 1) + quantized_ymax = ( + ymax / size_per_bin_h).floor().clamp(0, bins_h - 1) + + elif self.mode == 'round': + raise NotImplementedError() + + else: + raise ValueError('Incorrect quantization type.') + + quantized_boxes = torch.cat( + (quantized_xmin, quantized_ymin, quantized_xmax, quantized_ymax), dim=-1 + ).int() + + return quantized_boxes + + def dequantize(self, boxes: torch.Tensor, size): + bins_w, bins_h = self.bins # Quantization bins. + size_w, size_h = size # Original image size. + size_per_bin_w = size_w / bins_w + size_per_bin_h = size_h / bins_h + xmin, ymin, xmax, ymax = boxes.split(1, dim=-1) # Shape: 4 * [N, 1]. + + if self.mode == 'floor': + # Add 0.5 to use the center position of the bin as the coordinate. + dequantized_xmin = (xmin + 0.5) * size_per_bin_w + dequantized_ymin = (ymin + 0.5) * size_per_bin_h + dequantized_xmax = (xmax + 0.5) * size_per_bin_w + dequantized_ymax = (ymax + 0.5) * size_per_bin_h + + elif self.mode == 'round': + raise NotImplementedError() + + else: + raise ValueError('Incorrect quantization type.') + + dequantized_boxes = torch.cat( + (dequantized_xmin, dequantized_ymin, + dequantized_xmax, dequantized_ymax), dim=-1 + ) + + return dequantized_boxes + + +class CoordinatesQuantizer(object): + """ + Quantize coornidates (Nx2) + """ + + def __init__(self, mode, bins): + self.mode = mode + self.bins = bins + + def quantize(self, coordinates: torch.Tensor, size): + bins_w, bins_h = self.bins # Quantization bins. + size_w, size_h = size # Original image size. + size_per_bin_w = size_w / bins_w + size_per_bin_h = size_h / bins_h + assert coordinates.shape[-1] == 2, 'coordinates should be shape (N, 2)' + x, y = coordinates.split(1, dim=-1) # Shape: 4 * [N, 1]. + + if self.mode == 'floor': + quantized_x = (x / size_per_bin_w).floor().clamp(0, bins_w - 1) + quantized_y = (y / size_per_bin_h).floor().clamp(0, bins_h - 1) + + elif self.mode == 'round': + raise NotImplementedError() + + else: + raise ValueError('Incorrect quantization type.') + + quantized_coordinates = torch.cat( + (quantized_x, quantized_y), dim=-1 + ).int() + + return quantized_coordinates + + def dequantize(self, coordinates: torch.Tensor, size): + bins_w, bins_h = self.bins # Quantization bins. + size_w, size_h = size # Original image size. + size_per_bin_w = size_w / bins_w + size_per_bin_h = size_h / bins_h + assert coordinates.shape[-1] == 2, 'coordinates should be shape (N, 2)' + x, y = coordinates.split(1, dim=-1) # Shape: 4 * [N, 1]. + + if self.mode == 'floor': + # Add 0.5 to use the center position of the bin as the coordinate. + dequantized_x = (x + 0.5) * size_per_bin_w + dequantized_y = (y + 0.5) * size_per_bin_h + + elif self.mode == 'round': + raise NotImplementedError() + + else: + raise ValueError('Incorrect quantization type.') + + dequantized_coordinates = torch.cat( + (dequantized_x, dequantized_y), dim=-1 + ) + + return dequantized_coordinates + + +class Florence2PostProcesser(object): + r""" + Florence-2 post process for converting text prediction to various tasks results. + + Args: + config: A dict of configs. + tokenizer: A tokenizer for decoding text to spans. + sample config: + UNIFIED_POST_PROCESS: + # commom configs + NUM_BBOX_HEIGHT_BINS: 1000 + NUM_BBOX_WIDTH_BINS: 1000 + COORDINATES_HEIGHT_BINS: 1000 + COORDINATES_WIDTH_BINS: 1000 + # task specific configs, override the common configs + PRASE_TASKS: + - TASK_NAME: 'video_dense_caption' + PATTERN: 'r([a-zA-Z0-9 ]+)' + SCORE_MODE: 'avg_cat_name_scores' + NUM_BINS: 100 + - TASK_NAME: 'od' + PATTERN: 'r([a-zA-Z0-9 ]+)' + SCORE_MODE: 'avg_cat_name_scores' + + Returns: + parsed_dict (dict): A dict of parsed results. + """ + def __init__( + self, + tokenizer=None + ): + parse_tasks = [] + parse_task_configs = {} + config = self._create_default_config() + for task in config['PARSE_TASKS']: + parse_tasks.append(task['TASK_NAME']) + parse_task_configs[task['TASK_NAME']] = task + + self.config = config + self.parse_tasks = parse_tasks + self.parse_tasks_configs = parse_task_configs + + self.tokenizer = tokenizer + if self.tokenizer is not None: + self.all_special_tokens = set(self.tokenizer.all_special_tokens) + + self.init_quantizers() + self.black_list_of_phrase_grounding = self._create_black_list_of_phrase_grounding() + + def _create_black_list_of_phrase_grounding(self): + black_list = {} + + if 'phrase_grounding' in self.parse_tasks and self.parse_tasks_configs['phrase_grounding']['FILTER_BY_BLACK_LIST']: + black_list = set( + ['it', 'I', 'me', 'mine', + 'you', 'your', 'yours', + 'he', 'him', 'his', + 'she', 'her', 'hers', + 'they', 'them', 'their', 'theirs', + 'one', 'oneself', + 'we', 'us', 'our', 'ours', + 'you', 'your', 'yours', + 'they', 'them', 'their', 'theirs', + 'mine', 'yours', 'his', 'hers', 'its', + 'ours', 'yours', 'theirs', + 'myself', 'yourself', 'himself', 'herself', 'itself', + 'ourselves', 'yourselves', 'themselves', + 'this', 'that', + 'these', 'those', + 'who', 'whom', 'whose', 'which', 'what', + 'who', 'whom', 'whose', 'which', 'that', + 'all', 'another', 'any', 'anybody', 'anyone', 'anything', + 'each', 'everybody', 'everyone', 'everything', + 'few', 'many', 'nobody', 'none', 'one', 'several', + 'some', 'somebody', 'someone', 'something', + 'each other', 'one another', + 'myself', 'yourself', 'himself', 'herself', 'itself', + 'ourselves', 'yourselves', 'themselves', + 'the image', 'image', 'images', 'the', 'a', 'an', 'a group', + 'other objects', 'lots', 'a set', + ] + ) + + return black_list + + def _create_default_config(self): + config = { + 'NUM_BBOX_HEIGHT_BINS': 1000, + 'NUM_BBOX_WIDTH_BINS': 1000, + 'BOX_QUANTIZATION_MODE': 'floor', + 'COORDINATES_HEIGHT_BINS': 1000, + 'COORDINATES_WIDTH_BINS': 1000, + 'COORDINATES_QUANTIZATION_MODE': 'floor', + 'PARSE_TASKS': [ + { + 'TASK_NAME': 'od', + 'PATTERN': r'([a-zA-Z0-9 ]+)' + }, + { + 'TASK_NAME': 'ocr', + 'PATTERN': r'(.+?)', + 'AREA_THRESHOLD': 0.00 + }, + { + 'TASK_NAME': 'phrase_grounding', + 'FILTER_BY_BLACK_LIST': True + }, + { + 'TASK_NAME': 'pure_text', + }, + { + 'TASK_NAME': 'description_with_bboxes', + }, + { + 'TASK_NAME': 'description_with_polygons', + }, + { + 'TASK_NAME': 'polygons', + }, + { + 'TASK_NAME': 'bboxes', + }, + { + 'TASK_NAME': 'description_with_bboxes_or_polygons', + } + ] + } + + return config + + def init_quantizers(self): + # we have box_quantizer (od, grounding) and coordinates_quantizer (ocr, referring_segmentation) + num_bbox_height_bins = self.config.get('NUM_BBOX_HEIGHT_BINS', 1000) + num_bbox_width_bins = self.config.get('NUM_BBOX_WIDTH_BINS', 1000) + box_quantization_mode = self.config.get('BOX_QUANTIZATION_MODE', 'floor') + self.box_quantizer = BoxQuantizer( + box_quantization_mode, + (num_bbox_width_bins, num_bbox_height_bins), + ) + + num_bbox_height_bins = self.config['COORDINATES_HEIGHT_BINS'] if 'COORDINATES_HEIGHT_BINS' in self.config else self.config.get('NUM_BBOX_HEIGHT_BINS', 1000) + num_bbox_width_bins = self.config['COORDINATES_WIDTH_BINS'] if 'COORDINATES_WIDTH_BINS' in self.config else self.config.get('NUM_BBOX_WIDTH_BINS', 1000) + box_quantization_mode = self.config.get('COORDINATES_QUANTIZATION_MODE') if 'COORDINATES_QUANTIZATION_MODE' in self.config else self.config.get('BOX_QUANTIZATION_MODE', 'floor') + self.coordinates_quantizer = CoordinatesQuantizer( + box_quantization_mode, + (num_bbox_width_bins, num_bbox_height_bins), + ) + + def decode_with_spans(self, tokenizer, token_ids): + filtered_tokens = tokenizer.convert_ids_to_tokens( + token_ids, skip_special_tokens=False) + assert len(filtered_tokens) == len(token_ids) + + # To avoid mixing byte-level and unicode for byte-level BPT + # we need to build string separately for added tokens and byte-level tokens + # cf. https://github.com/huggingface/transformers/issues/1133 + sub_texts = [] + for token in filtered_tokens: + if token in self.all_special_tokens: + sub_texts.append(token) + else: + if isinstance(tokenizer, (BartTokenizer, BartTokenizerFast)): + sub_text = tokenizer.convert_tokens_to_string([token]) + elif isinstance(tokenizer, (T5Tokenizer, T5TokenizerFast)): + # Ref: https://github.com/google/sentencepiece#whitespace-is-treated-as-a-basic-symbol + # Note: Do not strip sub_text as it may have functional whitespace + sub_text = token.replace('▁', ' ') + else: + raise ValueError(f'type {type(tokenizer)} not supported') + sub_texts.append(sub_text) + + text = '' + spans = [] + for sub_text in sub_texts: + span = (len(text), len(text) + len(sub_text)) # [start index, end index). + text += sub_text + spans.append(span) + + # Text format: + # 1. T5Tokenizer/T5TokenizerFast: + # " transplanting dog cat" + # Equivalent to t5_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False) + # 2. BartTokenizer (need to double check): + # "transplanting dogcat" + # Equivalent to bart_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False) + return text, spans + + def parse_od_from_text_and_spans( + self, + text, + pattern, + image_size, + phrase_centric=False + ): + parsed = list(re.finditer(pattern, text)) + + instances = [] + for i in range(len(parsed)): + # Prepare instance. + instance = {} + + if phrase_centric: + bbox_bins = [int(parsed[i].group(j)) for j in range(2, 6)] + else: + bbox_bins = [int(parsed[i].group(j)) for j in range(1, 5)] + instance['bbox'] = self.box_quantizer.dequantize( + boxes=torch.tensor(bbox_bins), + size=image_size + ).tolist() + + if phrase_centric: + instance['cat_name'] = parsed[i].group(1).lower().strip() + else: + instance['cat_name'] = parsed[i].group(5).lower().strip() + instances.append(instance) + + return instances + + def parse_ocr_from_text_and_spans(self, + text, + pattern, + image_size, + area_threshold=-1.0, + ): + bboxes = [] + labels = [] + text = text.replace('', '') + # ocr with regions + parsed = re.findall(pattern, text) + instances = [] + image_width, image_height = image_size + + for ocr_line in parsed: + ocr_content = ocr_line[0] + quad_box = ocr_line[1:] + quad_box = [int(i) for i in quad_box] + quad_box = self.coordinates_quantizer.dequantize( + torch.tensor(np.array(quad_box).reshape(-1, 2)), + size=image_size + ).reshape(-1).tolist() + + if area_threshold > 0: + x_coords = [i for i in quad_box[0::2]] + y_coords = [i for i in quad_box[1::2]] + + # apply the Shoelace formula + area = 0.5 * abs(sum(x_coords[i] * y_coords[i + 1] - x_coords[i + 1] * y_coords[i] for i in range(4 - 1))) + + if area < (image_width * image_height) * area_threshold: + continue + + bboxes.append(quad_box) + labels.append(ocr_content) + instances.append({ + 'quad_box': quad_box, + 'text': ocr_content, + }) + return instances + + def parse_phrase_grounding_from_text_and_spans(self, text, pattern, image_size): + # ignore and + cur_span = 0 + if text.startswith(''): + cur_span += 3 + + text = text.replace('', '') + text = text.replace('', '') + text = text.replace('', '') + + pattern = r"([^<]+(?:){4,})" + phrases = re.findall(pattern, text) + + # pattern should be text pattern and od pattern + pattern = r'^\s*(.*?)(?=||||||' + + instances = [] + for pharse_text in phrases: + phrase_text_strip = pharse_text.replace('', '', 1) + phrase_text_strip = pharse_text.replace('', '', 1) + + if phrase_text_strip == '': + cur_span += len(pharse_text) + continue + + # Prepare instance. + instance = {} + + # parse phrase, get string + phrase = re.search(pattern, phrase_text_strip) + if phrase is None: + cur_span += len(pharse_text) + continue + + # parse bboxes by box_pattern + bboxes_parsed = list(re.finditer(box_pattern, pharse_text)) + if len(bboxes_parsed) == 0: + cur_span += len(pharse_text) + continue + + phrase = phrase.group() + # remove leading and trailing spaces + phrase = phrase.strip() + + if phrase in self.black_list_of_phrase_grounding: + cur_span += len(pharse_text) + continue + + # a list of list + bbox_bins = [[int(_bboxes_parsed.group(j)) for j in range(1, 5)] for _bboxes_parsed in bboxes_parsed] + instance['bbox'] = self.box_quantizer.dequantize( + boxes=torch.tensor(bbox_bins), + size=image_size + ).tolist() + + # exclude non-ascii characters + phrase = phrase.encode('ascii',errors='ignore').decode('ascii') + instance['cat_name'] = phrase + + instances.append(instance) + + return instances + + def parse_description_with_bboxes_from_text_and_spans(self, text, pattern, image_size, allow_empty_phrase=False): + # temporary parse solution, split by '.' + # ignore and + + text = text.replace('', '') + text = text.replace('', '') + text = text.replace('', '') + + if allow_empty_phrase: + pattern = r"(?:(?:){4,})" + else: + pattern = r"([^<]+(?:){4,})" + phrases = re.findall(pattern, text) + + # pattern should be text pattern and od pattern + pattern = r'^\s*(.*?)(?=||||||' + + instances = [] + for pharse_text in phrases: + phrase_text_strip = pharse_text.replace('', '', 1) + phrase_text_strip = pharse_text.replace('', '', 1) + + if phrase_text_strip == '' and not allow_empty_phrase: + continue + + # parse phrase, get string + phrase = re.search(pattern, phrase_text_strip) + if phrase is None: + continue + + phrase = phrase.group() + # remove leading and trailing spaces + phrase = phrase.strip() + + # parse bboxes by box_pattern + bboxes_parsed = list(re.finditer(box_pattern, pharse_text)) + if len(bboxes_parsed) == 0: + continue + + # a list of list + bbox_bins = [[int(_bboxes_parsed.group(j)) for j in range(1, 5)] for _bboxes_parsed in bboxes_parsed] + + bboxes = self.box_quantizer.dequantize( + boxes=torch.tensor(bbox_bins), + size=image_size + ).tolist() + + phrase = phrase.encode('ascii',errors='ignore').decode('ascii') + for _bboxes in bboxes: + # Prepare instance. + instance = {} + instance['bbox'] = _bboxes + # exclude non-ascii characters + instance['cat_name'] = phrase + instances.append(instance) + + return instances + + def parse_description_with_polygons_from_text_and_spans(self, text, pattern, image_size, + allow_empty_phrase=False, + polygon_sep_token='', + polygon_start_token='', + polygon_end_token='', + with_box_at_start=False, + ): + + # ref_seg format: '<><><><><><>' + # ignore and + + text = text.replace('', '') + text = text.replace('', '') + text = text.replace('', '') + + if allow_empty_phrase: + pattern = rf"(?:(?:|{re.escape(polygon_sep_token)}|{re.escape(polygon_start_token)}|{re.escape(polygon_end_token)}){{4,}})" + else: + # [^<]+: This part matches one or more characters that are not the < symbol. + # The ^ inside the square brackets [] is a negation, meaning it matches anything except <. + # + pattern = rf"([^<]+(?:|{re.escape(polygon_sep_token)}|{re.escape(polygon_start_token)}|{re.escape(polygon_end_token)}){{4,}})" + phrases = re.findall(pattern, text) + + phrase_string_pattern = r'^\s*(.*?)(?=||||||)' + box_pattern = rf'((?:)+)(?:{re.escape(polygon_sep_token)}|$)' + + # one polygons instance is separated by polygon_start_token and polygon_end_token + polygons_instance_pattern = rf'{re.escape(polygon_start_token)}(.*?){re.escape(polygon_end_token)}' + + instances = [] + for phrase_text in phrases: + + # exclude loc_\d+> + # need to get span if want to include category score + phrase_text_strip = re.sub(r'^loc_\d+>', '', phrase_text, count=1) + + # phrase = phrase.replace('', '') + # phrase = phrase.replace('poly>', '') + + if phrase_text_strip == '' and not allow_empty_phrase: + continue + + + # parse phrase, get string + phrase = re.search(phrase_string_pattern, phrase_text_strip) + if phrase is None: + continue + phrase = phrase.group() + # remove leading and trailing spaces + phrase = phrase.strip() + + # parse bboxes by box_pattern + + # split by polygon_start_token and polygon_end_token first using polygons_instance_pattern + if polygon_start_token in phrase_text and polygon_end_token in phrase_text: + polygons_instances_parsed = list(re.finditer(polygons_instance_pattern, phrase_text)) + else: + polygons_instances_parsed = [phrase_text] + + for _polygons_instances_parsed in polygons_instances_parsed: + # Prepare instance. + instance = {} + + # polygons_parsed= list(re.finditer(box_pattern, phrase_text)) + if isinstance(_polygons_instances_parsed, str): + polygons_parsed= list(re.finditer(box_pattern, _polygons_instances_parsed)) + else: + polygons_parsed= list(re.finditer(box_pattern, _polygons_instances_parsed.group(1))) + if len(polygons_parsed) == 0: + continue + + # a list of list (polygon) + bbox = [] + polygons = [] + for _polygon_parsed in polygons_parsed: + # group 1: whole ... + _polygon = _polygon_parsed.group(1) + # parse into list of int + _polygon = [int(_loc_parsed.group(1)) for _loc_parsed in re.finditer(r'', _polygon)] + if with_box_at_start and len(bbox) == 0: + if len(_polygon) > 4: + # no valid bbox prediction + bbox = _polygon[:4] + _polygon = _polygon[4:] + else: + bbox = [0, 0, 0, 0] + # abandon last element if is not paired + if len(_polygon) % 2 == 1: + _polygon = _polygon[:-1] + + # reshape into (n, 2) + _polygon = self.coordinates_quantizer.dequantize( + torch.tensor(np.array(_polygon).reshape(-1, 2)), + size=image_size + ).reshape(-1).tolist() + # reshape back + polygons.append(_polygon) + + instance['cat_name'] = phrase + instance['polygons'] = polygons + if len(bbox) != 0: + instance['bbox'] = self.box_quantizer.dequantize( + boxes=torch.tensor([bbox]), + size=image_size + ).tolist()[0] + + instances.append(instance) + + return instances + + def __call__( + self, + text=None, + image_size=None, + parse_tasks=None, + ): + """ + Args: + text: model outputs + image_size: (width, height) + parse_tasks: a list of tasks to parse, if None, parse all tasks. + + """ + if parse_tasks is not None: + if isinstance(parse_tasks, str): + parse_tasks = [parse_tasks] + for _parse_task in parse_tasks: + assert _parse_task in self.parse_tasks, f'parse task {_parse_task} not supported' + + # sequence or text should be provided + assert text is not None, 'text should be provided' + + parsed_dict = { + 'text': text + } + + for task in self.parse_tasks: + if parse_tasks is not None and task not in parse_tasks: + continue + + pattern = self.parse_tasks_configs[task].get('PATTERN', None) + + if task == 'ocr': + instances = self.parse_ocr_from_text_and_spans( + text, + pattern=pattern, + image_size=image_size, + area_threshold=self.parse_tasks_configs[task].get('AREA_THRESHOLD', 0.0), + ) + parsed_dict['ocr'] = instances + elif task == 'phrase_grounding': + instances = self.parse_phrase_grounding_from_text_and_spans( + text, + pattern=pattern, + image_size=image_size, + ) + parsed_dict['phrase_grounding'] = instances + elif task == 'pure_text': + parsed_dict['pure_text'] = text + elif task == 'description_with_bboxes': + instances = self.parse_description_with_bboxes_from_text_and_spans( + text, + pattern=pattern, + image_size=image_size, + ) + parsed_dict['description_with_bboxes'] = instances + elif task == 'description_with_polygons': + instances = self.parse_description_with_polygons_from_text_and_spans( + text, + pattern=pattern, + image_size=image_size, + ) + parsed_dict['description_with_polygons'] = instances + elif task == 'polygons': + instances = self.parse_description_with_polygons_from_text_and_spans( + text, + pattern=pattern, + image_size=image_size, + allow_empty_phrase=True, + ) + parsed_dict['polygons'] = instances + elif task == 'bboxes': + instances = self.parse_description_with_bboxes_from_text_and_spans( + text, + pattern=pattern, + image_size=image_size, + allow_empty_phrase=True, + ) + parsed_dict['bboxes'] = instances + elif task == 'description_with_bboxes_or_polygons': + if '' in text: + # only support either polygons or bboxes, not both at the same time + instances = self.parse_description_with_polygons_from_text_and_spans( + text, + pattern=pattern, + image_size=image_size, + ) + else: + instances = self.parse_description_with_bboxes_from_text_and_spans( + text, + pattern=pattern, + image_size=image_size, + ) + parsed_dict['description_with_bboxes_or_polygons'] = instances + else: + raise ValueError("task {} is not supported".format(task)) + + return parsed_dict diff --git a/scripts/configs/flor2/train_colflor_model.yaml b/scripts/configs/flor2/train_colflor_model.yaml new file mode 100644 index 00000000..704c137d --- /dev/null +++ b/scripts/configs/flor2/train_colflor_model.yaml @@ -0,0 +1,64 @@ +config: + (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig + output_dir: !path ../../../models/colflor-trained + processor: + (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper + class_to_instanciate: !ext colpali_engine.models.ColFlorProcessor + pretrained_model_name_or_path: "./models/ColFlor-base" + use_fast: false + # max_length: 50 + + model: + (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper + class_to_instanciate: !ext colpali_engine.models.ColFlor + pretrained_model_name_or_path: "./models/ColFlor-base" + torch_dtype: !ext torch.bfloat16 + # torch_dtype: !ext torch.float32 + # use_cache: false +# device_map: "auto" +# quantization_config: +# (): transformers.BitsAndBytesConfig +# load_in_4bit: true +# bnb_4bit_quant_type: "nf4" +# bnb_4bit_compute_dtype: "bfloat16" +# bnb_4bit_use_double_quant: true + + dataset_loading_func: !ext colpali_engine.utils.dataset_transformation.load_train_set_detailed + eval_dataset_loader: !import ../data/test_data.yaml + + # max_length: 50 + run_eval: true + add_suffix: true + loss_func: + (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseCELoss + tr_args: + (): transformers.training_args.TrainingArguments + output_dir: null + overwrite_output_dir: true + num_train_epochs: 3 + per_device_train_batch_size: 16 + # 6 x 8 gpus = 48 batch size + # gradient_accumulation_steps: 4 + per_device_eval_batch_size: 16 + eval_strategy: "steps" + # dataloader_num_workers: 8 + # bf16: true + save_steps: 500 + # max_steps: 100 + logging_steps: 10 + eval_steps: 100 + warmup_steps: 100 + learning_rate: 5e-5 + save_total_limit: 1 + # optim: "paged_adamw_8bit" + +# peft_config: +# (): peft.LoraConfig +# r: 32 +# lora_alpha: 32 +# lora_dropout: 0.1 +# init_lora_weights: "gaussian" +# bias: "none" +# task_type: "FEATURE_EXTRACTION" +# target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' +# # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'