[trainer] support disable shuffling #6388

Merged · 1 commit · Dec 19, 2024
3 changes: 3 additions & 0 deletions .gitignore
@@ -172,3 +172,6 @@ saves/
output/
wandb/
generated_predictions.jsonl

# unittest
dummy_dir/
4 changes: 4 additions & 0 deletions src/llamafactory/hparams/finetuning_args.py
@@ -342,6 +342,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
default=False,
metadata={"help": "Whether or not to compute the token-level accuracy at evaluation."},
)
disable_shuffling: bool = field(
default=False,
metadata={"help": "Whether or not to disable the shuffling of the training set."},
)
plot_loss: bool = field(
default=False,
metadata={"help": "Whether or not to save the training loss curves."},
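For reference, the new option rides the normal argument-parsing path like any other finetuning argument. Below is a minimal sketch of switching it on, modeled on the unit test added at the end of this diff; the model and dataset names are the tiny test fixtures used there, and the output directory is arbitrary.

from llamafactory.hparams import get_train_args

# Hedged sketch: enable the new flag and check that it lands on FinetuningArguments.
train_args = {
    "model_name_or_path": "llamafactory/tiny-random-Llama-3",
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "lora",
    "dataset": "llamafactory/tiny-supervised-dataset",
    "dataset_dir": "ONLINE",
    "template": "llama3",
    "cutoff_len": 1024,
    "output_dir": "dummy_dir/shuffle_demo",  # arbitrary scratch directory
    "overwrite_output_dir": True,
    "per_device_train_batch_size": 1,
    "max_steps": 1,
    "disable_shuffling": True,  # the option added in this PR
}
model_args, data_args, training_args, finetuning_args, _ = get_train_args(train_args)
assert finetuning_args.disable_shuffling is True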
13 changes: 11 additions & 2 deletions src/llamafactory/train/dpo/trainer.py
@@ -19,7 +19,7 @@
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

import torch
import torch.nn.functional as F
@@ -119,6 +119,13 @@ def create_scheduler(
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)

@override
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)

return super()._get_train_sampler()

@override
def get_batch_samples(self, epoch_iterator, num_batches):
r"""
@@ -266,7 +273,9 @@ def get_batch_loss_metrics(
return losses.mean(), metrics

@override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
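Every trainer touched by this PR applies the same pattern: when disable_shuffling is set, _get_train_sampler returns a SequentialSampler so the DataLoader visits the training set in dataset order, and otherwise the parent class keeps its default shuffling sampler. A standalone torch sketch (not part of the PR) of the difference between the two samplers:

import torch
from torch.utils.data import RandomSampler, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(5))
print(list(SequentialSampler(dataset)))  # always [0, 1, 2, 3, 4] -- dataset order preserved
print(list(RandomSampler(dataset)))      # a random permutation of [0, 1, 2, 3, 4]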
9 changes: 7 additions & 2 deletions src/llamafactory/train/kto/trainer.py
@@ -19,7 +19,7 @@
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

import torch
from transformers import Trainer
@@ -119,6 +119,9 @@ def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
r"""
Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
"""
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)

return Trainer._get_train_sampler(self)

@override
@@ -245,7 +248,9 @@ def get_batch_loss_metrics(
return losses, metrics

@override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
17 changes: 13 additions & 4 deletions src/llamafactory/train/pt/trainer.py
@@ -13,8 +13,9 @@
# limitations under the License.

from types import MethodType
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
from transformers import Trainer
from typing_extensions import override

@@ -24,8 +25,7 @@


if TYPE_CHECKING:
import torch
from transformers import ProcessorMixin
from transformers import PreTrainedModel, ProcessorMixin

from ...hparams import FinetuningArguments

@@ -70,7 +70,16 @@ def create_scheduler(
return super().create_scheduler(num_training_steps, optimizer)

@override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)

return super()._get_train_sampler()

@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
7 changes: 7 additions & 0 deletions src/llamafactory/train/rm/trainer.py
@@ -81,6 +81,13 @@ def create_scheduler(
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)

@override
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)

return super()._get_train_sampler()

@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
13 changes: 11 additions & 2 deletions src/llamafactory/train/sft/trainer.py
@@ -34,7 +34,7 @@

if TYPE_CHECKING:
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.trainer import PredictionOutput

from ...hparams import FinetuningArguments
@@ -85,7 +85,16 @@ def create_scheduler(
return super().create_scheduler(num_training_steps, optimizer)

@override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)

return super()._get_train_sampler()

@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
4 changes: 2 additions & 2 deletions tests/e2e/test_train.py
@@ -60,12 +60,12 @@
],
)
def test_run_exp(stage: str, dataset: str):
output_dir = os.path.join("output", f"train_{stage}")
output_dir = os.path.join("output", f"dummy_dir/train_{stage}")
run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
assert os.path.exists(output_dir)


def test_export():
export_dir = os.path.join("output", "llama3_export")
export_dir = os.path.join("output", "dummy_dir/llama3_export")
export_model({"export_dir": export_dir, **INFER_ARGS})
assert os.path.exists(export_dir)
81 changes: 81 additions & 0 deletions tests/train/test_sft_trainer.py
@@ -0,0 +1,81 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass, field
from typing import Any, Dict, List

import pytest
from transformers import DataCollatorWithPadding

from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.hparams import get_train_args
from llamafactory.model import load_model, load_tokenizer
from llamafactory.train.sft.trainer import CustomSeq2SeqTrainer


DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")

TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")

TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
"stage": "sft",
"do_train": True,
"finetuning_type": "lora",
"dataset": "llamafactory/tiny-supervised-dataset",
"dataset_dir": "ONLINE",
"template": "llama3",
"cutoff_len": 1024,
"overwrite_cache": False,
"overwrite_output_dir": True,
"per_device_train_batch_size": 1,
"max_steps": 1,
}


@dataclass
class DataCollatorWithVerbose(DataCollatorWithPadding):
verbose_list: List[Dict[str, Any]] = field(default_factory=list)

def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
self.verbose_list.extend(features)
batch = super().__call__(features)
return {k: v[:, :1] for k, v in batch.items()} # truncate input length


@pytest.mark.parametrize("disable_shuffling", [False, True])
def test_shuffle(disable_shuffling: bool):
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
{"output_dir": f"dummy_dir/{disable_shuffling}", "disable_shuffling": disable_shuffling, **TRAIN_ARGS}
)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = DataCollatorWithVerbose(tokenizer=tokenizer)
trainer = CustomSeq2SeqTrainer(
model=model,
args=training_args,
finetuning_args=finetuning_args,
data_collator=data_collator,
**dataset_module,
**tokenizer_module,
)
trainer.train()
if disable_shuffling:
assert data_collator.verbose_list[0]["input_ids"] == dataset_module["train_dataset"][0]["input_ids"]
else:
assert data_collator.verbose_list[0]["input_ids"] != dataset_module["train_dataset"][0]["input_ids"]
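If it helps to verify the behaviour locally, the new parametrized test can be run on its own; a hedged example assuming pytest is installed and the repository root is the current working directory:

import pytest

# Run only the shuffling test from this PR (both parametrized cases).
pytest.main(["tests/train/test_sft_trainer.py", "-k", "test_shuffle"])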