@@ -111,40 +111,27 @@ def prediction_step(
         inputs: Dict[str, Union["torch.Tensor", Any]],
         prediction_loss_only: bool,
         ignore_keys: Optional[List[str]] = None,
+        **gen_kwargs,
     ) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
         r"""
         Removes the prompt part from the generated tokens.

         Subclass and override to inject custom behavior.
         """
-        labels = inputs["labels"] if "labels" in inputs else None
-        if self.args.predict_with_generate:
-            assert self.processing_class.padding_side == "left", "This method only accepts left-padded tensor."
-            labels = labels.detach().clone() if labels is not None else None  # backup labels
-            prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
-            if prompt_len > label_len:
-                inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
-            if label_len > prompt_len:  # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
-                inputs["labels"] = inputs["labels"][:, :prompt_len]
-
-        loss, generated_tokens, _ = super().prediction_step(  # ignore the returned labels (may be truncated)
-            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
+        if self.args.predict_with_generate:  # do not pass labels to the model when generating
+            labels = inputs.pop("labels", None)
+        else:
+            labels = inputs.get("labels")
+
+        loss, generated_tokens, _ = super().prediction_step(
+            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
         )
         if generated_tokens is not None and self.args.predict_with_generate:
-            generated_tokens[:, :prompt_len] = self.processing_class.pad_token_id
+            generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id
             generated_tokens = generated_tokens.contiguous()

         return loss, generated_tokens, labels

-    def _pad_tensors_to_target_len(self, src_tensor: "torch.Tensor", tgt_tensor: "torch.Tensor") -> "torch.Tensor":
-        r"""
-        Pads the tensor to the same length as the target tensor.
-        """
-        assert self.processing_class.pad_token_id is not None, "Pad token is required."
-        padded_tensor = self.processing_class.pad_token_id * torch.ones_like(tgt_tensor)
-        padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor  # adopt left-padding
-        return padded_tensor.contiguous()  # in contiguous memory
-
     def save_predictions(
         self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
     ) -> None:
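Note on the masking line above: slicing generated_tokens[:, : inputs["input_ids"].size(-1)] removes exactly the prompt only because the batch is left-padded, so every row of the generate() output begins with prompt_len prompt/pad positions. A minimal standalone sketch of this behavior (toy tensors with an illustrative pad_token_id of 0; not part of the patch itself):

import torch

pad_token_id = 0

# Two left-padded prompts, prompt_len = 4 (batch_size = 2).
input_ids = torch.tensor([
    [0, 0, 5, 6],  # short prompt, padded on the left
    [3, 4, 5, 6],  # full-length prompt
])

# With left padding, generate() returns [prompt | completion] per row,
# so the first input_ids.size(-1) positions of every row are the prompt.
generated_tokens = torch.tensor([
    [0, 0, 5, 6, 11, 12],
    [3, 4, 5, 6, 21, 22],
])

# Same masking as in the patch: overwrite the prompt part with pad tokens
# so decoding the predictions yields only the completion.
generated_tokens[:, : input_ids.size(-1)] = pad_token_id
# -> tensor([[ 0,  0,  0,  0, 11, 12],
#            [ 0,  0,  0,  0, 21, 22]])

Popping "labels" from inputs before calling super().prediction_step() is also what makes the old _pad_tensors_to_target_len helper unnecessary: the labels no longer need to be padded or truncated to match the prompt length, and the added **gen_kwargs parameter matches the parent prediction_step call so generation arguments can pass through.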