From 6ba31a8a94bf7cfeaf59ffc3bc9e0b0cd3e25788 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Thu, 17 Oct 2024 17:01:56 -0400 Subject: [PATCH] Enable users to use their own loss functions + deal with prefetching for grad accum (#34198) * bookmark * Bookmark * Bookmark * Actually implement * Pass in kwarg explicitly * Adjust for if we do or don't have labels * Bookmark fix for od * bookmark * Fin * closer * Negate accelerate grad accum div * Fixup not training long enough * Add in compute_loss to take full model output * Document * compute_loss -> compute_loss_fn * Add a test * Refactor * Refactor * Uncomment tests * Update tests/trainer/test_trainer.py Co-authored-by: Daniel Han --------- Co-authored-by: Daniel Han --- src/transformers/trainer.py | 290 ++++++++++++++++++++-------------- tests/trainer/test_trainer.py | 159 ++++++++++++++++++- 2 files changed, 325 insertions(+), 124 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 7e4d1e5d267bb8..58a20f66f4e81b 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -340,12 +340,16 @@ class Trainer: The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to be able to choose different architectures according to hyper parameters (such as layer count, sizes of inner layers, dropout probabilities etc). + compute_loss_func (`Callable`, *optional*): + A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated + batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, here is one using + the loss function from `transformers` compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return a dictionary string to metric values. *Note* When passing TrainingArgs with `batch_eval_metrics` set to `True`, your compute_metrics function must take a boolean `compute_result` argument. This will be triggered after the last eval batch to signal that the function needs to calculate and return the global summary - statistics rather than accumulating the batch-level statistics. + statistics rather than accumulating the batch-level statistics callbacks (List of [`TrainerCallback`], *optional*): A list of callbacks to customize the training loop. Will add those to the list of default callbacks detailed in [here](callback). @@ -394,6 +398,7 @@ def __init__( Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] ] = None, model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_loss_func: Optional[Callable] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, callbacks: Optional[List[TrainerCallback]] = None, optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), @@ -415,6 +420,7 @@ def __init__( f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. 
" ) self.args = args + self.compute_loss_func = compute_loss_func # Seed must be set before instantiating the model when using model enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) @@ -2369,16 +2375,16 @@ def _inner_training_loop( total_batched_samples = 0 for epoch in range(epochs_trained, num_train_epochs): - epoch_iterator = train_dataloader - if hasattr(epoch_iterator, "set_epoch"): - epoch_iterator.set_epoch(epoch) + epoch_dataloader = train_dataloader + if hasattr(epoch_dataloader, "set_epoch"): + epoch_dataloader.set_epoch(epoch) # Reset the past mems state at the beginning of each epoch if necessary. if args.past_index >= 0: self._past = None steps_in_epoch = ( - len(epoch_iterator) + len(epoch_dataloader) if len_dataloader is not None else args.max_steps * args.gradient_accumulation_steps ) @@ -2390,142 +2396,154 @@ def _inner_training_loop( rng_to_sync = False steps_skipped = 0 if steps_trained_in_current_epoch > 0: - epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch) + epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch) steps_skipped = steps_trained_in_current_epoch steps_trained_in_current_epoch = 0 rng_to_sync = True step = -1 - for step, inputs in enumerate(epoch_iterator): - total_batched_samples += 1 - - if self.args.include_num_input_tokens_seen: - main_input_name = getattr(self.model, "main_input_name", "input_ids") - if main_input_name not in inputs: - logger.warning( - "Tried to track the number of tokens seen, however the current model is " - "not configured properly to know what item is the input. To fix this, add " - "a `main_input_name` attribute to the model class you are using." - ) + epoch_iterator = iter(epoch_dataloader) + # We chunkify the epoch iterator into gradient accumulation steps `n` batches + remainder = num_examples % args.gradient_accumulation_steps + num_items_in_batch = None + if remainder == 0: + remainder = args.gradient_accumulation_steps + update_step = -1 + total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1 + for _ in range(total_updates): + update_step += 1 + num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder + batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches) + for inputs in batch_samples: + step += 1 + total_batched_samples += 1 + # Since we perform prefetching, we need to manually set sync_gradients + if total_batched_samples % args.gradient_accumulation_steps != 0: + self.accelerator.gradient_state._set_sync_gradients(False) else: - self.state.num_input_tokens_seen += ( - torch.sum( - self.accelerator.gather( - torch.tensor( - inputs[main_input_name].numel(), device=self.args.device, dtype=torch.int64 - ) - ) + self.accelerator.gradient_state._set_sync_gradients(True) + + if self.args.include_num_input_tokens_seen: + main_input_name = getattr(self.model, "main_input_name", "input_ids") + if main_input_name not in inputs: + logger.warning( + "Tried to track the number of tokens seen, however the current model is " + "not configured properly to know what item is the input. To fix this, add " + "a `main_input_name` attribute to the model class you are using." 
) - .cpu() - .item() - ) - if rng_to_sync: - self._load_rng_state(resume_from_checkpoint) - rng_to_sync = False - - # Skip past any already trained steps if resuming training - if steps_trained_in_current_epoch > 0: - steps_trained_in_current_epoch -= 1 - if steps_trained_progress_bar is not None: - steps_trained_progress_bar.update(1) - if steps_trained_in_current_epoch == 0: + else: + input_tokens = inputs[main_input_name].numel() + input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64) + self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).cpu().item() + if rng_to_sync: self._load_rng_state(resume_from_checkpoint) - continue - elif steps_trained_progress_bar is not None: - steps_trained_progress_bar.close() - steps_trained_progress_bar = None - - if step % args.gradient_accumulation_steps == 0: - self.control = self.callback_handler.on_step_begin(args, self.state, self.control) - - with self.accelerator.accumulate(model): - tr_loss_step = self.training_step(model, inputs) - - if ( - args.logging_nan_inf_filter - and not is_torch_xla_available() - and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) - ): - # if loss is nan or inf simply add the average of previous logged losses - tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) - else: - if tr_loss.device != tr_loss_step.device: - raise ValueError( - f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}" - ) - tr_loss = tr_loss + tr_loss_step + rng_to_sync = False + + # Skip past any already trained steps if resuming training + if steps_trained_in_current_epoch > 0: + steps_trained_in_current_epoch -= 1 + if steps_trained_progress_bar is not None: + steps_trained_progress_bar.update(1) + if steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + continue + elif steps_trained_progress_bar is not None: + steps_trained_progress_bar.close() + steps_trained_progress_bar = None + + if step % args.gradient_accumulation_steps == 0: + self.control = self.callback_handler.on_step_begin(args, self.state, self.control) + + with self.accelerator.accumulate(model): + tr_loss_step = self.training_step(model, inputs, num_items_in_batch) + + if ( + args.logging_nan_inf_filter + and not is_torch_xla_available() + and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) + ): + # if loss is nan or inf simply add the average of previous logged losses + tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) + else: + if tr_loss.device != tr_loss_step.device: + raise ValueError( + f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}" + ) + tr_loss = tr_loss + tr_loss_step - self.current_flos += float(self.floating_point_ops(inputs)) + self.current_flos += float(self.floating_point_ops(inputs)) - is_last_step_and_steps_less_than_grad_acc = ( - steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch - ) + is_last_step_and_steps_less_than_grad_acc = ( + steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch + ) - if ( - total_batched_samples % args.gradient_accumulation_steps == 0 - or - # last step in epoch but step is always smaller than gradient_accumulation_steps - is_last_step_and_steps_less_than_grad_acc - ): - # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered - # in accelerate. 
So, explicitly enable sync gradients to True in that case. - if is_last_step_and_steps_less_than_grad_acc: + if ( + (total_batched_samples) % args.gradient_accumulation_steps == 0 + or + # last step in epoch but step is always smaller than gradient_accumulation_steps + is_last_step_and_steps_less_than_grad_acc + ): + # Since we perform prefetching, we need to manually set sync_gradients to True self.accelerator.gradient_state._set_sync_gradients(True) - # Gradient clipping - if args.max_grad_norm is not None and args.max_grad_norm > 0: - # deepspeed does its own clipping - - if is_sagemaker_mp_enabled() and args.fp16: - _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm) - elif self.use_apex: - # Revert to normal clipping otherwise, handling Apex or full precision - _grad_norm = nn.utils.clip_grad_norm_( - amp.master_params(self.optimizer), - args.max_grad_norm, - ) - else: - _grad_norm = self.accelerator.clip_grad_norm_( - model.parameters(), - args.max_grad_norm, - ) - - if ( - is_accelerate_available() - and self.accelerator.distributed_type == DistributedType.DEEPSPEED - ): - grad_norm = model.get_global_grad_norm() - # In some cases the grad norm may not return a float - if hasattr(grad_norm, "item"): - grad_norm = grad_norm.item() - else: - grad_norm = _grad_norm + # Gradient clipping + if args.max_grad_norm is not None and args.max_grad_norm > 0: + # deepspeed does its own clipping + + if is_sagemaker_mp_enabled() and args.fp16: + _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm) + elif self.use_apex: + # Revert to normal clipping otherwise, handling Apex or full precision + _grad_norm = nn.utils.clip_grad_norm_( + amp.master_params(self.optimizer), + args.max_grad_norm, + ) + else: + _grad_norm = self.accelerator.clip_grad_norm_( + model.parameters(), + args.max_grad_norm, + ) - self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control) + if ( + is_accelerate_available() + and self.accelerator.distributed_type == DistributedType.DEEPSPEED + ): + grad_norm = model.get_global_grad_norm() + # In some cases the grad norm may not return a float + if hasattr(grad_norm, "item"): + grad_norm = grad_norm.item() + else: + grad_norm = _grad_norm - self.optimizer.step() + self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control) - self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control) + self.optimizer.step() - optimizer_was_run = not self.accelerator.optimizer_step_was_skipped - if optimizer_was_run: - # Delay optimizer scheduling until metrics are generated - if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): - self.lr_scheduler.step() + self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control) - model.zero_grad() - self.state.global_step += 1 - self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch - self.control = self.callback_handler.on_step_end(args, self.state, self.control) + optimizer_was_run = not self.accelerator.optimizer_step_was_skipped + if optimizer_was_run: + # Delay optimizer scheduling until metrics are generated + if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step() - self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval) - else: - self.control = self.callback_handler.on_substep_end(args, self.state, self.control) + model.zero_grad() + self.state.global_step += 1 + self.state.epoch 
= epoch + (step + 1 + steps_skipped) / steps_in_epoch + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval) + else: + self.control = self.callback_handler.on_substep_end(args, self.state, self.control) - if self.control.should_epoch_stop or self.control.should_training_stop: # PyTorch/XLA relies on the data loader to insert the mark_step for # each step. Since we are breaking the loop early, we need to manually # insert the mark_step here. + if self.control.should_epoch_stop or self.control.should_training_stop: + if is_torch_xla_available(): + xm.mark_step() + break + # We also need to break out of the nested loop + if self.control.should_epoch_stop or self.control.should_training_stop: if is_torch_xla_available(): xm.mark_step() break @@ -3514,7 +3532,9 @@ def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True): return ctx_manager - def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + def training_step( + self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None + ) -> torch.Tensor: """ Perform a training step on a batch of inputs. @@ -3542,7 +3562,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, return loss_mb.reduce_mean().detach().to(self.args.device) with self.compute_loss_context_manager(): - loss = self.compute_loss(model, inputs) + loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) del inputs if ( @@ -3575,20 +3595,23 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() else: + loss *= self.args.gradient_accumulation_steps self.accelerator.backward(loss, **kwargs) return loss.detach() / self.args.gradient_accumulation_steps - def compute_loss(self, model, inputs, return_outputs=False): + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """ How the loss is computed by Trainer. By default, all models return the loss in the first element. Subclass and override for custom behavior. """ - if self.label_smoother is not None and "labels" in inputs: + if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs: labels = inputs.pop("labels") else: labels = None + # if num_items_in_batch is not None: + # inputs["num_items_in_batch"] = num_items_in_batch outputs = model(**inputs) # Save past state if it exists # TODO: this needs to be fixed and made cleaner later. 
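A note on the arithmetic behind the `num_items_in_batch` plumbing above: when micro-batches contain different numbers of valid label tokens, taking the mean loss per micro-batch and averaging across the accumulation window is not the same as the mean over the whole accumulated batch, whereas summing per-token losses and dividing once by the total token count is. A minimal numeric sketch (the loss values and token counts are invented for illustration):

import torch

# Two micro-batches with different numbers of valid label tokens.
per_token_losses = [torch.tensor([2.0, 4.0]), torch.tensor([1.0, 1.0, 1.0, 1.0])]

# Mean per micro-batch, then averaged across the accumulation window (previous behaviour).
mean_of_means = torch.stack([l.mean() for l in per_token_losses]).mean()  # (3.0 + 1.0) / 2 = 2.0

# Sum of per-token losses divided once by num_items_in_batch (new behaviour).
num_items_in_batch = sum(l.numel() for l in per_token_losses)  # 6 tokens in total
global_mean = torch.cat(per_token_losses).sum() / num_items_in_batch  # 10.0 / 6 ≈ 1.667

print(mean_of_means.item(), global_mean.item())  # 2.0 vs ~1.667

The `loss *= self.args.gradient_accumulation_steps` line in `training_step` appears to exist to cancel the division by the accumulation step count that Accelerate applies inside `accelerator.backward` (the "Negate accelerate grad accum div" entry in the commit list), so a loss that is already normalized by `num_items_in_batch` is not scaled down a second time; the detached value returned for logging is divided back down.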
@@ -3601,7 +3624,10 @@ def compute_loss(self, model, inputs, return_outputs=False): model_name = unwrapped_model.base_model.model._get_name() else: model_name = unwrapped_model._get_name() - if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): + # User-defined compute_loss function + if self.compute_loss_func is not None: + loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch) + elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): loss = self.label_smoother(outputs, labels, shift_labels=True) else: loss = self.label_smoother(outputs, labels) @@ -4993,3 +5019,21 @@ def _fsdp_qlora_plugin_updates(self): fsdp_plugin.set_mixed_precision( self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True ) + + def get_batch_samples(self, epoch_iterator, num_batches): + batch_samples = [] + num_items_in_batch = None + for _ in range(num_batches): + try: + batch_samples += [next(epoch_iterator)] + except StopIteration: + break + if len(batch_samples) > 0 and "labels" in batch_samples[0]: + # For now we don't support object detection + try: + num_items_in_batch = sum( + [data_batch["labels"][..., 1:].ne(-100).sum().item() for data_batch in batch_samples] + ) + except TypeError: + pass + return batch_samples, num_items_in_batch diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index cbc93faf50e7a3..5c03355785d2b5 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -42,6 +42,7 @@ AutoImageProcessor, AutoProcessor, AutoTokenizer, + DataCollatorForLanguageModeling, IntervalStrategy, PretrainedConfig, TrainerCallback, @@ -49,6 +50,7 @@ get_polynomial_decay_schedule_with_warmup, is_torch_available, logging, + set_seed, ) from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS from transformers.testing_utils import ( @@ -153,6 +155,19 @@ PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt" +class StoreLossCallback(TrainerCallback): + """ + Simple callback to store the loss. 
+ """ + + def __init__(self): + self.losses = [] + + def on_log(self, args, state, control, logs=None, **kwargs): + if "loss" in logs: + self.losses.append(logs["loss"]) + + class MockCudaOOMCallback(TrainerCallback): """ Simple callback to simulate CUDA OOM error if @@ -168,6 +183,26 @@ def on_step_end(self, args, state, control, **kwargs): raise RuntimeError("CUDA out of memory.") +def ForCausalLMLoss(logits, labels, vocab_size, num_items_in_batch, disable_num_items_in_batch=False): + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # Flatten the tokens + shift_logits = shift_logits.view(-1, vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + if num_items_in_batch is None or disable_num_items_in_batch: + loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="mean") + else: + loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="sum") + loss = loss / num_items_in_batch + return loss + + class RegressionDataset: def __init__(self, a=2, b=3, length=64, seed=42, label_names=None): np.random.seed(seed) @@ -438,6 +473,31 @@ def forward(self, input_x, labels=None, **kwargs): loss = nn.functional.mse_loss(y, labels) return (loss, y) + class BasicTextGenerationModel(nn.Module): + def __init__(self, vocab_size, hidden_size): + super().__init__() + self.embedding = nn.Embedding(vocab_size, hidden_size) + self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True) + self.fc = nn.Linear(hidden_size, vocab_size) + + def forward(self, input_ids, **kwargs): + embedded = self.embedding(input_ids) + lstm_out, _ = self.lstm(embedded) + logits = self.fc(lstm_out) + return logits + + def create_dummy_dataset_for_text_generation(vocab_size, seq_length, num_samples): + import datasets + import numpy as np + + # Create random input sequences + input_ids = np.random.randint(0, vocab_size, (num_samples, seq_length)) + + # Create a datasets.Dataset + dataset = datasets.Dataset.from_dict({"input_ids": input_ids, "labels": input_ids}) + + return dataset + class TstLayer(nn.Module): def __init__(self, hidden_size): super().__init__() @@ -676,8 +736,105 @@ def test_model_init(self): trainer.train() self.check_trained_model(trainer.model, alternate_seed=True) + @slow + def test_gradient_accumulation_loss_alignment(self): + set_seed(42) + import datasets + + model_name = "distilgpt2" + dataset_name = "wikitext" + dataset_config = "wikitext-2-raw-v1" + dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]") + dataset = dataset.train_test_split(test_size=0.2) + tokenizer = AutoTokenizer.from_pretrained(model_name) + + def tokenize_function(examples): + return tokenizer(examples["text"]) + + tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names) + + tokenizer.pad_token = tokenizer.eos_token + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + model = AutoModelForCausalLM.from_pretrained(model_name) + + def compute_loss(logits, labels, vocab_size, num_items_in_batch, disable_num_items_in_batch=False): + return ForCausalLMLoss( + logits["logits"], labels, vocab_size, num_items_in_batch, disable_num_items_in_batch + ) + + loss_fn = partial(compute_loss, 
vocab_size=model.config.vocab_size, disable_num_items_in_batch=False) + + base_loss_callback = StoreLossCallback() + + args_kwargs = { + "report_to": "none", + "logging_steps": 1, + "max_steps": 20, + "learning_rate": 3e-4, + "disable_tqdm": True, + } + + args = TrainingArguments( + "./generation", + **args_kwargs, + ) + trainer = Trainer( + model, + args, + train_dataset=tokenized_dataset["train"], + callbacks=[base_loss_callback], + compute_loss_func=loss_fn, + data_collator=data_collator, + ) + trainer.train() + + grad_accum_loss_callback = StoreLossCallback() + args = TrainingArguments( + "./generation", + **args_kwargs, + gradient_accumulation_steps=2, + per_device_train_batch_size=4, + ) + set_seed(42) + model = AutoModelForCausalLM.from_pretrained(model_name) + trainer = Trainer( + model, + args, + train_dataset=tokenized_dataset["train"], + callbacks=[grad_accum_loss_callback], + compute_loss_func=loss_fn, + data_collator=data_collator, + ) + trainer.train() + + set_seed(42) + model = AutoModelForCausalLM.from_pretrained(model_name) + broken_loss_callback = StoreLossCallback() + loss_fn = partial(compute_loss, vocab_size=model.config.vocab_size, disable_num_items_in_batch=True) + trainer = Trainer( + model, + args, + train_dataset=tokenized_dataset["train"], + callbacks=[broken_loss_callback], + compute_loss_func=loss_fn, + data_collator=data_collator, + ) + trainer.train() + + # Calculate the difference between the base loss and the grad_accum loss + diff_truth = [base - grad for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)] + diff_broken = [base - grad for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)] + # These should be quite close + for diff in diff_truth: + self.assertLess(abs(diff), 0.1, f"Difference {diff} is not within 0.1") + + # These should be very off + for diff in diff_broken: + self.assertGreater(abs(diff), 0.1, f"Difference {diff} is not greater than 0.1") + def test_gradient_accumulation(self): - # Training with half the batch size but accumulation steps as 2 should give the same results. + # Training with half the batch size but accumulation steps as 2 should give the same training losses. trainer = get_regression_trainer( gradient_accumulation_steps=2, per_device_train_batch_size=4, learning_rate=0.1 )
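For readers tracing the restructured inner loop in `_inner_training_loop`: it now pre-fetches one accumulation window of micro-batches at a time so that `num_items_in_batch` is known before any forward pass, which is what lets the loss be normalized by the full accumulated batch. A standalone sketch of that pattern, mirroring the `get_batch_samples` helper added above (the toy batches are invented for illustration):

import torch

def get_batch_samples(epoch_iterator, num_batches):
    # Pull up to `num_batches` micro-batches ahead of time.
    batch_samples = []
    num_items_in_batch = None
    for _ in range(num_batches):
        try:
            batch_samples.append(next(epoch_iterator))
        except StopIteration:
            break
    if batch_samples and "labels" in batch_samples[0]:
        # Count non-ignored label tokens across the whole window
        # (ignore_index is -100; the first position drops out after the causal shift).
        num_items_in_batch = sum(
            batch["labels"][..., 1:].ne(-100).sum().item() for batch in batch_samples
        )
    return batch_samples, num_items_in_batch

batches = iter([
    {"labels": torch.tensor([[5, 6, -100, -100]])},
    {"labels": torch.tensor([[7, 8, 9, 10]])},
])
samples, num_items = get_batch_samples(batches, num_batches=2)
print(len(samples), num_items)  # 2 micro-batches, 4 valid shifted label tokens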
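End to end, the new hook is used by passing `compute_loss_func` to `Trainer`; it receives the raw model outputs, the labels (which the `Trainer` pops from the inputs when a custom loss function is set), and `num_items_in_batch`. A hedged usage sketch based on the signature documented above; the toy dataset, the `per_token_ce_loss` name, and the hyperparameters are illustrative only:

import datasets
import numpy as np
from torch import nn
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

model = AutoModelForCausalLM.from_pretrained("distilgpt2")

# Tiny pre-tokenized dataset; a real run would tokenize text and use a data collator.
ids = np.random.randint(0, model.config.vocab_size, (32, 16))
train_dataset = datasets.Dataset.from_dict({"input_ids": ids, "labels": ids})

def per_token_ce_loss(outputs, labels, num_items_in_batch=None):
    # Causal-LM shift, then sum the cross-entropy over valid (non -100) tokens.
    logits = outputs["logits"].float()
    shift_logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
    loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="sum")
    # Normalize once by the token count of the whole accumulated batch when it is known.
    if num_items_in_batch is not None:
        return loss / num_items_in_batch
    return loss / shift_labels.ne(-100).sum()

args = TrainingArguments(
    "./out", per_device_train_batch_size=4, gradient_accumulation_steps=2, max_steps=8, report_to="none"
)
trainer = Trainer(model, args, train_dataset=train_dataset, compute_loss_func=per_token_ce_loss)
trainer.train()

With this normalization, the logged losses with and without gradient accumulation should track each other closely, which is what the `test_gradient_accumulation_loss_alignment` test above asserts.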