Commit

Generate: visit non-llm prepare_inputs_for_generation (#34199)
* tmp

* all visited

* test all

* Update src/transformers/models/moshi/modeling_moshi.py

Co-authored-by: Arthur <[email protected]>

* delete another one :D

---------

Co-authored-by: Arthur <[email protected]>
gante and ArthurZucker authored Oct 17, 2024
1 parent 1d2c29f commit f51ac9e
Showing 64 changed files with 140 additions and 1,134 deletions.
30 changes: 20 additions & 10 deletions src/transformers/generation/utils.py
@@ -390,13 +390,16 @@ def prepare_inputs_for_generation(
# 3. Prepare base model inputs
input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and not self.config.is_encoder_decoder and cache_position[0] == 0:
model_inputs[input_ids_key] = None
model_inputs["inputs_embeds"] = inputs_embeds
if not self.config.is_encoder_decoder:
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs[input_ids_key] = None
model_inputs["inputs_embeds"] = inputs_embeds
else:
# `clone` calls in this function ensure a consistent stride. See #32227
model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
model_inputs["inputs_embeds"] = None
else:
# `clone` calls in this function ensure a consistent stride. See #32227
model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
model_inputs["inputs_embeds"] = None

# 4. Create missing `position_ids` on the fly
if (
@@ -428,10 +431,15 @@ def prepare_inputs_for_generation(

# Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
# the 4D causal mask exists, it should be present in the base model (XXXModel class).
base_model = getattr(self, self.base_model_prefix)
causal_mask_creation_function = getattr(
base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
)
base_model = getattr(self, self.base_model_prefix, None)
if base_model is None:
causal_mask_creation_function = getattr(
self, "_prepare_4d_causal_attention_mask_with_cache_position", None
)
else:
causal_mask_creation_function = getattr(
base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
)
if causal_mask_creation_function is None:
logger.warning_once(
f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
@@ -444,10 +452,12 @@ def prepare_inputs_for_generation(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_cache_shape(),
dtype=self.get_output_embeddings().weight.dtype,
dtype=self.dtype,
device=device,
cache_position=cache_position,
batch_size=batch_size,
config=self.config,
past_key_values=past_key_values,
)
if attention_mask is not None:
model_inputs["attention_mask"] = attention_mask
1 change: 1 addition & 0 deletions src/transformers/models/bark/modeling_bark.py
@@ -578,6 +578,7 @@ def set_input_embeddings(self, new_embeddings):
self.input_embeds_layer = new_embeddings

def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
# Overwritten -- bark has a model-specific hack
input_embeds = kwargs.get("input_embeds", None)

attention_mask = kwargs.get("attention_mask", None)
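
bark keeps its override, and the new comment spells out the convention this commit establishes: models inherit `prepare_inputs_for_generation` from `GenerationMixin` and only keep an override when they genuinely need one, flagging it with an `# Overwritten -- ...` note. A minimal sketch of that convention (class names are hypothetical, not library code):

```python
from transformers import GenerationMixin


class RegularLM(GenerationMixin):
    # No override: generation uses the shared GenerationMixin implementation.
    pass


class QuirkyLM(GenerationMixin):
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # Overwritten -- hypothetical model-specific hack, mirroring bark's pattern
        kwargs["input_ids"] = input_ids[:, -1:]
        return kwargs


print(RegularLM.prepare_inputs_for_generation is GenerationMixin.prepare_inputs_for_generation)  # True
print(QuirkyLM.prepare_inputs_for_generation is GenerationMixin.prepare_inputs_for_generation)   # False
```
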
@@ -3020,23 +3020,6 @@ def forward(
cross_attentions=outputs.cross_attentions,
)

def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)

if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"use_cache": use_cache,
}

@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
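
The override deleted in this hunk only re-implemented what the inherited method already covers: create an attention mask when none is given and, once a cache exists, feed only the tokens the cache has not seen. A simplified, hedged comparison of the old per-model slicing and the cache-position-based slicing used by the shared implementation (toy functions, not the library code):

```python
import torch


def legacy_slice(input_ids, past_key_values):
    # Old per-model override: once a cache exists, feed only the newest token.
    return input_ids[:, -1:] if past_key_values else input_ids


def cache_position_slice(input_ids, cache_position):
    # Shared GenerationMixin behaviour (simplified): keep exactly the tokens the cache
    # has not processed yet, which also works when several new tokens arrive at once
    # (e.g. assisted decoding).
    return input_ids[:, cache_position]


input_ids = torch.tensor([[11, 12, 13, 14]])
print(legacy_slice(input_ids, past_key_values=True))          # tensor([[14]])
print(cache_position_slice(input_ids, torch.tensor([2, 3])))  # tensor([[13, 14]])
```
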
1 change: 1 addition & 0 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -806,6 +806,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
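
This `**kwargs` (repeated in the same spot for the other models below) lets the generic caller in `generation/utils.py` pass extra keyword arguments such as `config` and `past_key_values` to every model's mask helper without breaking helpers that do not use them. A small illustration with hypothetical helper names:

```python
def mask_helper_old(attention_mask, sequence_length, target_length):
    # Pre-PR signature: unknown keyword arguments raise a TypeError.
    return attention_mask, sequence_length, target_length


def mask_helper_new(attention_mask, sequence_length, target_length, **kwargs):
    # Post-PR signature: extras such as `config` or `past_key_values` are absorbed and ignored.
    return attention_mask, sequence_length, target_length


common_kwargs = dict(attention_mask=None, sequence_length=4, target_length=8, config="cfg")
try:
    mask_helper_old(**common_kwargs)
except TypeError as err:
    print("old signature breaks:", err)
print("new signature works:", mask_helper_new(**common_kwargs))
```
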
3 changes: 3 additions & 0 deletions src/transformers/models/chameleon/modeling_chameleon.py
@@ -1451,6 +1451,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
@@ -1644,6 +1645,8 @@ def prepare_inputs_for_generation(
use_cache=True,
**kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model

# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
1 change: 1 addition & 0 deletions src/transformers/models/codegen/modeling_codegen.py
@@ -649,6 +649,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1 change: 1 addition & 0 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -1017,6 +1017,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1 change: 1 addition & 0 deletions src/transformers/models/dbrx/modeling_dbrx.py
@@ -1179,6 +1179,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
@@ -26,6 +26,7 @@
from torch.nn import CrossEntropyLoss

from ...configuration_utils import PretrainedConfig
from ...generation import GenerationMixin
from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
@@ -166,7 +167,7 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start


@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
class EncoderDecoderModel(PreTrainedModel):
class EncoderDecoderModel(PreTrainedModel, GenerationMixin):
r"""
[`EncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one
of the base model classes of the library as encoder and another one as decoder when created with the
@@ -666,20 +667,6 @@ def forward(
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
input_dict = {
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_inputs.get("attention_mask"),
"decoder_input_ids": decoder_inputs["input_ids"],
"encoder_outputs": encoder_outputs,
"past_key_values": decoder_inputs.get("past_key_values"),
"use_cache": use_cache,
}
return input_dict

def resize_token_embeddings(self, *args, **kwargs):
raise NotImplementedError(
"Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the"
1 change: 1 addition & 0 deletions src/transformers/models/falcon/modeling_falcon.py
@@ -1188,6 +1188,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
2 changes: 2 additions & 0 deletions src/transformers/models/fuyu/modeling_fuyu.py
@@ -344,6 +344,8 @@ def prepare_inputs_for_generation(
image_patches_indices=None,
**kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model

if past_key_values:
input_ids = input_ids[:, -1:]

1 change: 1 addition & 0 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -933,6 +933,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1 change: 1 addition & 0 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -905,6 +905,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
2 changes: 2 additions & 0 deletions src/transformers/models/git/modeling_git.py
@@ -1609,6 +1609,8 @@ def forward(
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
# Overwritten -- `git` has special cache handling and doesn't support generating from `inputs_embeds` atm

# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
past_length = past_key_values.get_seq_length()
1 change: 1 addition & 0 deletions src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -863,6 +863,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1 change: 1 addition & 0 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -1060,6 +1060,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
@@ -764,6 +764,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1 change: 1 addition & 0 deletions src/transformers/models/gptj/modeling_gptj.py
@@ -958,6 +958,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1 change: 1 addition & 0 deletions src/transformers/models/granite/modeling_granite.py
@@ -958,6 +958,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1 change: 1 addition & 0 deletions src/transformers/models/granitemoe/modeling_granitemoe.py
@@ -1194,6 +1194,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
3 changes: 3 additions & 0 deletions src/transformers/models/idefics/modeling_idefics.py
@@ -1441,6 +1441,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
@@ -1674,6 +1675,8 @@ def prepare_inputs_for_generation(
use_cache=None,
**kwargs,
):
# Overwritten -- custom processing based on `config.use_resampler`

model_inputs = {}
if image_hidden_states is not None:
if self.config.use_resampler:
3 changes: 3 additions & 0 deletions src/transformers/models/idefics2/modeling_idefics2.py
@@ -1665,6 +1665,9 @@ def prepare_inputs_for_generation(
num_logits_to_keep=None,
**kwargs,
):
# Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take
# precedence is moved to the model, we can remove this fn)

# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
3 changes: 3 additions & 0 deletions src/transformers/models/idefics3/modeling_idefics3.py
@@ -1256,6 +1256,9 @@ def prepare_inputs_for_generation(
num_logits_to_keep=None,
**kwargs,
):
# Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take
# precedence is moved to the model, we can remove this fn)

# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
1 change: 1 addition & 0 deletions src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -1160,6 +1160,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
2 changes: 2 additions & 0 deletions src/transformers/models/kosmos2/modeling_kosmos2.py
@@ -1696,6 +1696,8 @@ def prepare_inputs_for_generation(
use_cache=None,
**model_kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model

input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
1 change: 1 addition & 0 deletions src/transformers/models/llama/modeling_llama.py
@@ -1053,6 +1053,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
2 changes: 2 additions & 0 deletions src/transformers/models/llava/modeling_llava.py
@@ -590,6 +590,8 @@ def prepare_inputs_for_generation(
num_logits_to_keep=None,
**kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model

# Trigger the new behavior if we have more than image embeddings seq length tokens for images
legacy_processing = (
input_ids is not None
2 changes: 2 additions & 0 deletions src/transformers/models/llava_next/modeling_llava_next.py
@@ -968,6 +968,8 @@ def prepare_inputs_for_generation(
num_logits_to_keep=None,
**kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model

legacy_processing = (
input_ids is not None
and (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
@@ -1057,6 +1057,8 @@ def prepare_inputs_for_generation(
num_logits_to_keep=None,
**kwargs,
):
# Overwritten -- extra custom processing

if input_ids is not None:
img_token_not_enough = (input_ids == self.config.image_token_index).sum(
1
@@ -572,6 +572,8 @@ def prepare_inputs_for_generation(
num_logits_to_keep=None,
**kwargs,
):
# Overwritten -- extra custom processing

if input_ids is not None:
img_token_not_enough = (input_ids == self.config.image_token_index).sum(
1
@@ -728,6 +728,8 @@ def prepare_inputs_for_generation(
num_logits_to_keep=None,
**kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model

model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
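
The pattern above (shared by several multimodal models in this commit) is delegation: the wrapper calls its language model's `prepare_inputs_for_generation` and then decides whether image inputs should be forwarded for this step. A toy, hedged sketch of that shape (hypothetical classes, not library code):

```python
class ToyTextLM:
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids, **kwargs}


class ToyVisionLM:
    def __init__(self, language_model):
        self.language_model = language_model

    def prepare_inputs_for_generation(self, input_ids, pixel_values=None, cache_position=None, **kwargs):
        # Overwritten -- in specific circumstances we don't want to forward image inputs
        model_inputs = self.language_model.prepare_inputs_for_generation(
            input_ids, cache_position=cache_position, **kwargs
        )
        if cache_position is not None and cache_position[0] == 0:
            # Only the prefill step needs the image features; later steps reuse the cache.
            model_inputs["pixel_values"] = pixel_values
        return model_inputs


wrapper = ToyVisionLM(ToyTextLM())
print(wrapper.prepare_inputs_for_generation([[1, 2]], pixel_values="IMG", cache_position=[0]))
print(wrapper.prepare_inputs_for_generation([[3]], pixel_values="IMG", cache_position=[2]))
```
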