
[GKD] 0 loss #2217

Open · 2 of 4 tasks
nivibilla opened this issue Oct 10, 2024 · 3 comments

Labels
🐛 bug Something isn't working · 🏋 GKD Related to GKD

Comments

@nivibilla

System Info

pip install git+https://github.com/huggingface/transformers.git
pip install tokenizers==0.20.0
pip install accelerate==0.34.2
pip install git+https://github.com/huggingface/trl.git
pip install datasets==3.0.1
pip install huggingface_hub==0.25.1
pip install peft==0.13.0
pip install databricks-cli==0.18.0
pip install bitsandbytes==0.44.1
pip install flash-attn==2.6.3 --no-build-isolation
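
For reference, a quick way to record the exact versions that end up installed (a small snippet of my own, since transformers and trl are pulled from git):

import accelerate, datasets, peft, torch, transformers, trl

# Print the resolved version of each package used above.
for mod in (transformers, trl, peft, datasets, accelerate, torch):
    print(f"{mod.__name__}: {mod.__version__}")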

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction


def qlora_gkd_train():
    import datasets
    import torch
    import transformers

    from trl import (
        GKDConfig,
        GKDTrainer,
        LogCompletionsCallback,
    )
    from peft import LoraConfig, TaskType, prepare_model_for_kbit_training

    import json

    with open('/local_disk0/training_config.json') as f:
        training_config = json.load(f)

    # # testing memory usage for batch size
    training_config['max_steps'] = 10
    # training_config['per_device_train_batch_size'] = 32
    print(json.dumps(training_config, indent=4))

    print("loading tokenizer")
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        training_config['teacher_model_name_or_path'],
        padding_side="left",
        truncation_side="left",
    )
    tokenizer.pad_token = tokenizer.eos_token

    print("loading dataset")
    train_dataset = datasets.load_from_disk('/local_disk0/train')

    # Model    
    torch_dtype = torch.bfloat16
    quant_storage_dtype = torch.bfloat16

    quantization_config = transformers.BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch_dtype,
            bnb_4bit_quant_storage=quant_storage_dtype,
        )
    
    print("loading teacher model")
    teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
        training_config['student_model_name_or_path'],
        quantization_config=quantization_config,
        attn_implementation="flash_attention_2", # use sdpa, alternatively use "flash_attention_2"
        torch_dtype=quant_storage_dtype,
        device_map = "auto"
    )

    teacher_model = prepare_model_for_kbit_training(teacher_model)

    print("create student config")
    student_model_kwargs = dict(
        trust_remote_code=True,
        attn_implementation="flash_attention_2", # use sdpa, alternatively use "flash_attention_2"
        torch_dtype=quant_storage_dtype,
        use_cache=training_config['gradient_checkpointing'],
        device_map="auto",
        # quantization_config=quantization_config,
    )

    print("create student config")
    student_model_kwargs = dict(
        trust_remote_code=True,
        attn_implementation="flash_attention_2", # use sdpa, alternatively use "flash_attention_2"
        torch_dtype=quant_storage_dtype,
        use_cache=training_config['gradient_checkpointing'],
        device_map="auto",
        # quantization_config=quantization_config,
    )

    lora_config = LoraConfig(
        r=training_config['lora_r'],
        # target_modules="all-linear",
        target_modules=["q_proj", "k_proj", "v_proj"],
        task_type=TaskType.CAUSAL_LM,
        lora_alpha=training_config['lora_alpha'],
        lora_dropout=0.05
    )

    training_arguments = GKDConfig(
        model_init_kwargs=student_model_kwargs,
        save_strategy='epoch',
        report_to='mlflow',
        # save_steps=training_config['save_steps'],
        ddp_find_unused_parameters=False,
        gradient_checkpointing=training_config['gradient_checkpointing'],
        per_device_train_batch_size=training_config['per_device_train_batch_size'],
        gradient_accumulation_steps=training_config['gradient_accumulation_steps'],
        num_train_epochs=training_config['num_train_epochs'],
        learning_rate=training_config['learning_rate'],
        warmup_ratio=training_config['warmup_ratio'],
        lr_scheduler_type="cosine",
        bf16=True,
        max_steps=training_config['max_steps'],
        logging_steps=training_config['logging_steps'],
        output_dir=training_config['output_dir'],
        gradient_checkpointing_kwargs={'use_reentrant':False},
        max_seq_length=training_config['max_seq_len'],
        use_liger=training_config['use_liger'],
        # optim="paged_adamw_8bit",
        dataset_text_field='prompt',
        packing=False,
        # # gkd params
        temperature=0.9,
        max_new_tokens=1024,
    )

    print("start training")
    trainer = GKDTrainer(
        model=training_config['student_model_name_or_path'],
        teacher_model=teacher_model,
        args=training_arguments,
        train_dataset=train_dataset,
        processing_class=tokenizer,
        peft_config=lora_config,
    )

    if training_config['resume']:
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

import os

os.environ['ACCELERATE_BYPASS_DEVICE_MAP'] = "true"

qlora_gkd_train()

Contents of /local_disk0/training_config.json:

{
    "teacher_model_name_or_path": "/local_disk0/meta-llama/Llama-3.1-70B-Instruct",
    "student_model_name_or_path": "/local_disk0/meta-llama/Llama-3.2-3B-Instruct",
    "learning_rate": 1e-05,
    "per_device_train_batch_size": 4,
    "gradient_accumulation_steps": 1,
    "logging_steps": 1,
    "num_train_epochs": 15,
    "gradient_checkpointing": true,
    "use_peft": true,
    "lora_r": 64,
    "lora_alpha": 16,
    "max_seq_len": 1382,
    "use_liger": false,
    "warmup_ratio": 0.1,
    "resume": false,
    "max_steps": -1
}
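
For context on the run below: GKD trains the student against a generalized Jensen-Shannon divergence between the teacher and student token distributions, which is exactly zero only when the two distributions coincide. A minimal sketch of such a loss for illustration (a simplification, assuming 0 < beta < 1; not TRL's actual implementation):

import math

import torch
import torch.nn.functional as F

def generalized_jsd(student_logits, teacher_logits, beta=0.5):
    # Token-level log-probabilities over the vocabulary.
    student_logp = F.log_softmax(student_logits, dim=-1)
    teacher_logp = F.log_softmax(teacher_logits, dim=-1)
    # beta-weighted mixture of the two distributions, computed in log space.
    mixture_logp = torch.logsumexp(
        torch.stack([student_logp + math.log(1 - beta), teacher_logp + math.log(beta)]),
        dim=0,
    )
    # Both KL terms vanish when the student and teacher distributions are identical,
    # which is one way a training run can report a loss of exactly 0.
    kl_teacher = F.kl_div(mixture_logp, teacher_logp, reduction="batchmean", log_target=True)
    kl_student = F.kl_div(mixture_logp, student_logp, reduction="batchmean", log_target=True)
    return beta * kl_teacher + (1 - beta) * kl_student

# Identical logits give ~0.0; different logits give a positive value.
logits = torch.randn(2, 8, 128)
print(generalized_jsd(logits, logits).item())
print(generalized_jsd(logits, torch.randn(2, 8, 128)).item())
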
[screenshot: training log showing loss = 0.0 and grad_norm = 0.0]

I'm seeing 0 loss and 0 grad norm. Is this expected?

Expected behavior

The loss should not be 0.

@nivibilla (Author)

Follow-up from #2215.

@qgallouedec added the 🐛 bug and 🏋 GKD labels on Oct 11, 2024
@kashif (Collaborator) commented Oct 11, 2024

@nivibilla what are the keys in your dataset? Currently the data collator also checks for a prompt key and uses it to extract only the prompts: https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py#L265
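
A rough paraphrase of that check as a sketch (not the actual TRL code; the chat-template call and message slicing here are assumptions):

def extract_prompt(example, tokenizer, max_length=1024):
    if "prompt" in example:
        # When a "prompt" key is present, the pre-formatted prompt is used as-is.
        prompt_text = example["prompt"]
    else:
        # Otherwise the prompt is rebuilt from the chat "messages", dropping the
        # final assistant turn so only the prompt part remains.
        prompt_text = tokenizer.apply_chat_template(
            example["messages"][:-1], tokenize=False, add_generation_prompt=True
        )
    return tokenizer(prompt_text, truncation=True, max_length=max_length)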

@nivibilla (Author)

[screenshot: dataset preview showing a single prompt column]
I just have a single prompt column, named prompt.
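
For reference, the columns can be double-checked with something like this (assuming the dataset saved at /local_disk0/train that the script loads):

import datasets

train_dataset = datasets.load_from_disk("/local_disk0/train")
print(train_dataset)                     # column names and row count
print(train_dataset.features)            # feature types, e.g. a plain string "prompt" column
print(train_dataset[0]["prompt"][:200])  # peek at the first prompt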
