Merge pull request #6 from prajjwal1/meta
added Meta Trainer, tests, example
Showing 13 changed files with 544 additions and 14 deletions.
```diff
@@ -62,6 +62,9 @@ coverage.xml
 *.mo
 *.pot
 
+wandb/
+*.lock
+cache*
 # Django stuff:
 *.log
 local_settings.py
```
```diff
@@ -1,8 +1,8 @@
 style:
-	isort --recursive --multi-line=3 --trailing-comma --force-grid-wrap=0 --use-parentheses --line-width=88 fluence tests
-	black fluence tests
+	isort --recursive --multi-line=3 --trailing-comma --force-grid-wrap=0 --use-parentheses --line-width=88 fluence tests examples
+	black fluence tests examples
 
 quality:
 	flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-	black --check fluence tests
+	black --check fluence tests examples
```
@@ -0,0 +1,6 @@
### Running MAML

Usage
```bash
python3 examples/run_maml_glue.py --model_name_or_path bert-base-uncased --do_train --do_eval --max_seq_length 128 --per_device_train_batch_size 1 --learning_rate 2e-5 --output_dir /home/nlp/experiments/fluence_exp/ --overwrite_output_dir --per_device_eval_batch_size 4096 --data_dir $GLUE_DIR --train_task mrpc --eval_task sst-2 --save_steps=10000 --num_train_epochs=1 --output_file_name check --eval_method every_2
```
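For orientation, the example meta-trains on a support task (`--train_task mrpc`) and evaluates on a query task (`--eval_task sst-2`), with separate inner and outer learning rates. The snippet below is a generic first-order MAML loop in PyTorch, included only to illustrate the inner/outer structure; it is not the fluence `MetaTrainer` implementation, and the toy model, fake batches, and learning rates are placeholders.

```python
# Illustrative first-order MAML loop -- NOT the fluence MetaTrainer code.
# The linear model, random batches, and learning rates are placeholders.
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

inner_lr, outer_lr = 1e-3, 2e-5          # inner (adaptation) vs. outer (meta) step sizes
model = nn.Linear(8, 2)                   # stand-in for the sequence classifier
outer_opt = torch.optim.Adam(model.parameters(), lr=outer_lr)


def fake_batch():
    """Stand-in for a support/query batch of (features, labels)."""
    return torch.randn(4, 8), torch.randint(0, 2, (4,))


for step in range(10):
    support_x, support_y = fake_batch()   # support task (e.g. MRPC)
    query_x, query_y = fake_batch()       # query task (e.g. SST-2)

    # Inner loop: adapt a copy of the model on the support batch.
    learner = copy.deepcopy(model)
    inner_opt = torch.optim.SGD(learner.parameters(), lr=inner_lr)
    inner_loss = F.cross_entropy(learner(support_x), support_y)
    inner_opt.zero_grad()
    inner_loss.backward()
    inner_opt.step()

    # Outer loop: evaluate the adapted copy on the query batch and apply its
    # gradients to the original parameters (first-order approximation).
    learner.zero_grad()
    query_loss = F.cross_entropy(learner(query_x), query_y)
    query_loss.backward()
    outer_opt.zero_grad()
    for p, p_adapted in zip(model.parameters(), learner.parameters()):
        p.grad = p_adapted.grad.clone()
    outer_opt.step()
```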
@@ -0,0 +1,200 @@
```python
import dataclasses
import logging
import os
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from tqdm import tqdm, trange
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    EvalPrediction,
    GlueDataset,
)
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import (
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
    glue_compute_metrics,
    glue_output_modes,
    glue_tasks_num_labels,
    set_seed,
)
from transformers.data.data_collator import DataCollator

from fluence.meta import MetaDataset, MetaTrainer

logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={
            "help": (
                "Path to pretrained model or model identifier from"
                " huggingface.co/models"
            )
        }
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name"
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name"
        },
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Where do you want to store the pretrained models downloaded from s3"
            )
        },
    )


@dataclass
class MetaArguments(TrainingArguments):
    train_task: Optional[str] = field(
        default=None, metadata={"help": "Support dataset"}
    )
    eval_task: Optional[str] = field(default=None, metadata={"help": "Query dataset"})
    data_dir: Optional[str] = field(default=None)
    inner_learning_rate: float = field(default=2e-5)
    learning_rate: Optional[float] = field(default=2e-5)  # Outer
    max_len: int = field(default=80)
    eval_method: Optional[str] = field(default=None)
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences"
                " longer than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
    output_file_name: Optional[str] = field(default="results")


def main():
    parser = HfArgumentParser((ModelArguments, MetaArguments))
    model_args, training_args = parser.parse_args_into_dataclasses()
    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not"
            " empty. Use --overwrite_output_dir to overcome."
        )

    # Set seed
    set_seed(training_args.seed)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s,"
        " 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)
    set_seed(training_args.seed)

    try:
        num_labels = glue_tasks_num_labels[training_args.train_task]
        output_mode = glue_output_modes[training_args.train_task]
    except KeyError:
        raise ValueError("Task not found: %s" % (training_args.train_task))

    config = AutoConfig.from_pretrained(
        model_args.config_name
        if model_args.config_name
        else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=training_args.train_task,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name
        if model_args.tokenizer_name
        else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )

    def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
        def compute_metrics_fn(p: EvalPrediction) -> Dict:
            if output_mode == "classification":
                preds = np.argmax(p.predictions, axis=1)
            elif output_mode == "regression":
                preds = np.squeeze(p.predictions)
            return glue_compute_metrics(task_name, preds, p.label_ids)

        return compute_metrics_fn

    data_dir = {
        "mrpc": training_args.data_dir + "/MRPC",
        "sst-2": training_args.data_dir + "/SST-2",
        "cola": training_args.data_dir + "/Cola",
        "sts-b": training_args.data_dir + "/STS-B",
    }

    training_args.task_name = training_args.train_task
    training_args.data_dir = data_dir[training_args.task_name]
    train_dataset = GlueDataset(training_args, tokenizer=tokenizer)
    meta_dataset = MetaDataset(train_dataset)
    training_args.task_name = training_args.eval_task
    training_args.data_dir = data_dir[training_args.task_name]
    eval_dataset = GlueDataset(training_args, tokenizer=tokenizer, mode="dev")

    meta_trainer = MetaTrainer(
        model=model,
        args=training_args,
        train_dataset=meta_dataset,
        eval_dataset=eval_dataset,
        train_data_collator=torch.utils.data._utils.collate.default_collate,
        eval_data_collator=default_data_collator,
    )

    meta_trainer.train()


if __name__ == "__main__":
    main()
```
```diff
@@ -1 +1 @@
-__version__ = "0.1.4"
+__version__ = "0.1.5"
```
@@ -1 +1,3 @@
```python
from .meta_args import MetaArguments
from .meta_dataset import MetaDataset
from .meta_trainer import MetaTrainer
```
@@ -0,0 +1,28 @@
```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import TrainingArguments


@dataclass
class MetaArguments(TrainingArguments):
    train_task: Optional[str] = field(
        default=None, metadata={"help": "Support dataset"}
    )
    eval_task: Optional[str] = field(default=None, metadata={"help": "Query dataset"})
    data_dir: Optional[str] = field(default=None)
    inner_learning_rate: float = field(default=1e-3)
    outer_learning_rate: float = field(default=2e-5)
    max_len: int = field(default=80)
    eval_method: Optional[str] = field(default=None)
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences"
                " longer than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
```
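Since `MetaArguments` subclasses `TrainingArguments`, it can be parsed with `HfArgumentParser` just like the example script does. Below is a minimal sketch, assuming the class is importable as `fluence.meta.MetaArguments` (as the new `__init__.py` exports suggest); the paths and argument values are placeholders, not settings from the repository.

```python
from transformers import HfArgumentParser

from fluence.meta import MetaArguments

parser = HfArgumentParser(MetaArguments)
# An explicit argv-style list is passed here only for illustration; in a
# script you would call parse_args_into_dataclasses() with no arguments.
(meta_args,) = parser.parse_args_into_dataclasses(
    args=[
        "--output_dir", "/tmp/fluence_meta",   # required by TrainingArguments
        "--train_task", "mrpc",                # support dataset
        "--eval_task", "sst-2",                # query dataset
        "--data_dir", "/path/to/glue",         # placeholder GLUE root
        "--inner_learning_rate", "1e-3",
        "--outer_learning_rate", "2e-5",
    ]
)
print(meta_args.train_task, meta_args.eval_task, meta_args.inner_learning_rate)
```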