ValueError: The NeuronTrainer only accept NeuronTrainingArguments, but <class 'optimum.neuron.training_args.Seq2SeqNeuronTrainingArguments'> was provided. #693

Open
2 of 4 tasks
industrialeaf opened this issue Sep 6, 2024 · 0 comments
Assignees
Labels
bug Something isn't working

Comments

@industrialeaf
System Info

AWS EC2 instance: trn1.32xlarge

Platform:

- Platform: Linux-5.15.0-1031-aws-x86_64-with-glibc2.35
- Python version: 3.11.9


Python packages:

- `optimum-neuron` version: 0.0.24
- `neuron-sdk` version: 2.19.1
- `optimum` version: 1.20.0
- `transformers` version: 4.41.1
- `huggingface_hub` version: 0.24.6
- `torch` version: 2.1.2+cu121
- `aws-neuronx-runtime-discovery` version: 2.9
- `libneuronxla` version: 2.0.2335
- `neuronx-cc` version: 2.14.213.0
- `neuronx-distributed` version: 0.8.0
- `neuronx-hwm` version: NA
- `torch-neuronx` version: 2.1.2.2.2.0
- `torch-xla` version: 2.1.3
- `transformers-neuronx` version: NA

Who can help?

@michaelbenayoun

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction (minimal, reproducible, runnable)

Error message:

Traceback (most recent call last):
  File "/home/ubuntu/projects/seq2seq/train_t5_small.py", line 48, in <module>
    trainer = Seq2SeqNeuronTrainer(
              ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/py311/lib/python3.11/site-packages/optimum/neuron/trainers.py", line 144, in __init__
    raise ValueError(
ValueError: The NeuronTrainer only accept NeuronTrainingArguments, but <class 'optimum.neuron.training_args.Seq2SeqNeuronTrainingArguments'> was provided.

Minimal example to reproduce:

Run the following script with `torchrun train.py`.

from transformers import T5Tokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
from optimum.neuron import Seq2SeqNeuronTrainer, Seq2SeqNeuronTrainingArguments
from optimum.neuron.distributed import lazy_load_for_parallelism


# Load dataset
dataset = load_dataset("samsum")

# Load tokenizer
tokenizer = T5Tokenizer.from_pretrained("t5-small")

# Preprocess the data
def preprocess_function(examples):
    inputs = ["summarize: " + doc for doc in examples["dialogue"]]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding='max_length')

    # text_target replaces the deprecated as_target_tokenizer() context manager
    labels = tokenizer(text_target=examples["summary"], max_length=150, truncation=True, padding='max_length')

    model_inputs["labels"] = labels["input_ids"]
    print("keys", model_inputs.keys())
    print("len labels", len(model_inputs['labels']))
    print("len inpids", len(model_inputs['input_ids']))
    print("len attmsk", len(model_inputs['attention_mask']))
    return model_inputs

tokenized_dataset = dataset.map(preprocess_function, batched=True)

# Define training arguments
training_args = Seq2SeqNeuronTrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=False,  # should be false since we don't provide a generation_config
)

# Load model
with lazy_load_for_parallelism(tensor_parallel_size=training_args.tensor_parallel_size):
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Initialize the trainer
trainer = Seq2SeqNeuronTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
)

# Train the model
trainer.train()

Expected behavior

The NeuronTrainer argument check should accept Seq2SeqNeuronTrainingArguments, so that Seq2SeqNeuronTrainer can be constructed with them.

As a workaround, I patched the check that raises this error (optimum/neuron/trainers.py, around line 144 in the traceback above) so that it also accepts Seq2SeqNeuronTrainingArguments:

if not isinstance(self.args, (NeuronTrainingArguments, Seq2SeqNeuronTrainingArguments)):
    raise ValueError(
        f"The NeuronTrainer only accepts NeuronTrainingArguments and Seq2SeqNeuronTrainingArguments, but {type(self.args)} was provided."
    )
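
For anyone who wants to see the effect of the patched check without Neuron hardware, here is a minimal sketch using stand-in classes. The class names mirror the ones in the error message, but their bodies are empty placeholders, and the hierarchy (the two argument classes as siblings, neither inheriting from the other) is an assumption inferred from the error, not taken from the optimum-neuron source. `strict_check` and `patched_check` are hypothetical helper names.

```python
# Stand-in classes: hypothetical placeholders, NOT the real optimum-neuron code.
# Assumption: the two argument classes are siblings, so an isinstance check
# against NeuronTrainingArguments alone rejects the Seq2Seq variant.
class TrainingArgumentsBase:
    pass


class NeuronTrainingArguments(TrainingArgumentsBase):
    pass


class Seq2SeqNeuronTrainingArguments(TrainingArgumentsBase):
    pass


# isinstance accepts a tuple of classes, which keeps the check to one call.
ACCEPTED_ARGS = (NeuronTrainingArguments, Seq2SeqNeuronTrainingArguments)


def strict_check(args):
    """Mimics the check that raises in the traceback (single accepted class)."""
    if not isinstance(args, NeuronTrainingArguments):
        raise ValueError(
            f"The NeuronTrainer only accept NeuronTrainingArguments, but {type(args)} was provided."
        )


def patched_check(args):
    """Mimics the workaround: either argument class is accepted."""
    if not isinstance(args, ACCEPTED_ARGS):
        raise ValueError(
            "The NeuronTrainer only accepts NeuronTrainingArguments and "
            f"Seq2SeqNeuronTrainingArguments, but {type(args)} was provided."
        )


if __name__ == "__main__":
    seq2seq_args = Seq2SeqNeuronTrainingArguments()
    try:
        strict_check(seq2seq_args)
    except ValueError as err:
        print("strict check rejected the Seq2Seq arguments:", err)
    patched_check(seq2seq_args)
    print("patched check accepted the Seq2Seq arguments")
```

If the real argument classes turn out to share a common parent in the library, checking against that parent would probably be a cleaner fix than listing both classes, but I have not verified that here.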
industrialeaf added the bug label on Sep 6, 2024
michaelbenayoun self-assigned this on Sep 18, 2024