-
Notifications
You must be signed in to change notification settings - Fork 158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
First pass at integrating xP3 #60
base: main
Are you sure you want to change the base?
Changes from 1 commit
100f8cd
17b471e
e0b28a2
8aed084
ae113e4
ac984f5
91d43e6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
"""Constants relate to xP3""" | ||
|
||
XP3_TRAIN_TASKS_SPLIT = { | ||
'xp3:glue_mrpc_equivalent': {'test': 1725, 'train': 3668, 'validation': 408}, | ||
'xp3:glue_mrpc_generate_paraphrase': {'test': 1725, 'train': 3668, 'validation': 408}, | ||
'xp3:glue_mrpc_generate_sentence': {'test': 1725, 'train': 3668, 'validation': 408}, | ||
'xp3:glue_mrpc_paraphrase': {'test': 1725, 'train': 3668, 'validation': 408}, | ||
'xp3:glue_mrpc_replace': {'test': 1725, 'train': 3668, 'validation': 408}, | ||
'xp3:glue_mrpc_same_thing': {'test': 1725, 'train': 3668, 'validation': 408}, | ||
'xp3:glue_mrpc_want_to_know': {'test': 1725, 'train': 3668, 'validation': 408}, | ||
} |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This file is in progress atm. It's going to need a few more tweaks to get going. I went with the same approach at P3 (T0) as it seemed to make the most sense to me. If this is the wrong path to go down, happy to discuss. I separated the file out for hackability. It can go back into one of the |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
# Copyright 2022 The FLAN Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Configurations of all SeqIO tasks.""" | ||
|
||
import functools | ||
import os | ||
import copy | ||
|
||
from flan.v2 import constants | ||
from flan.v2 import constants_niv2 | ||
from flan.v2 import constants_xp3 | ||
from flan.v2 import constants_t0 | ||
from flan.v2 import postprocessors as post | ||
from flan.v2 import preprocessors as prep | ||
from flan.v2 import task_configs_v1 | ||
from flan.v2 import utils | ||
import frozendict | ||
|
||
import seqio | ||
import datasets | ||
from t5.evaluation import metrics as t5_metrics | ||
import tensorflow as tf | ||
import re | ||
from tqdm import tqdm | ||
import json | ||
|
||
|
||
# Output features (inputs/targets) shared by all FLAN SeqIO tasks.
DEFAULT_OUTPUT_FEATURES = constants.DEFAULT_OUTPUT_FEATURES
# Config record describing a single SeqIO task (source, preprocessors, metrics).
TaskConfig = task_configs_v1.TaskConfig
# mC4 multilingual SentencePiece model (250k pieces + 100 extra ids).
DEFAULT_SPM_PATH = "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model"
DEFAULT_VOCAB = seqio.SentencePieceVocabulary(DEFAULT_SPM_PATH)

# Registry of xP3 task configs; populated below and frozen at module end.
XP3_TASK_CONFIGS = {}
|
||
# Helper functions from our current xp3 script | ||
# Not sure if they will be needed | ||
|
||
def feature_to_spec(feature, length=False):
    """Translate a HuggingFace `datasets` feature into a `tf.TensorSpec`.

    `length` mirrors `datasets.Sequence.length`: falsy means a scalar,
    -1 an unknown-length axis, and any other int a fixed-length axis.
    Nested sequences, lists and dicts are handled recursively.
    Raises ValueError for feature types that cannot be mapped.
    """
    def _axis_shape():
        # Scalar unless a sequence length was propagated from an enclosing
        # datasets.Sequence.
        if not length:
            return ()
        return (None if length == -1 else length,)

    if isinstance(feature, datasets.ClassLabel):
        return tf.TensorSpec(shape=_axis_shape(), dtype=tf.int64)
    if isinstance(feature, datasets.Value):
        return tf.TensorSpec(shape=_axis_shape(),
                             dtype=getattr(tf.dtypes, feature.dtype))
    if hasattr(feature, "dtype") and hasattr(feature, "shape"):
        # Already spec-like (e.g. an Array feature carrying dtype/shape).
        return tf.TensorSpec(shape=feature.shape, dtype=feature.dtype)
    if isinstance(feature, datasets.Sequence):
        # Recurse into the element type, carrying the sequence length down.
        return feature_to_spec(feature.feature, length=feature.length)
    if isinstance(feature, list):
        return [feature_to_spec(item, length=length) for item in feature]
    if isinstance(feature, dict):
        return {key: feature_to_spec(value, length=length)
                for key, value in feature.items()}
    raise ValueError(f"Unparseable feature type {type(feature)}")
|
||
def hf_dataset_to_tf_dataset(dataset):
    """Wrap a HuggingFace dataset as a `tf.data.Dataset` via its iterator."""
    # Derive the output signature from the dataset's declared features.
    signature = {name: feature_to_spec(feature)
                 for name, feature in dataset.features.items()}
    return tf.data.Dataset.from_generator(dataset.__iter__,
                                          output_signature=signature)
|
||
def get_tf_dataset(split, shuffle_files, dataset_name, subset_name, split_mapping, seed):
    """Load one split of a HuggingFace dataset and convert it to tf.data.

    Args:
      split: canonical split name; mapped through `split_mapping` to the
        dataset's own split name.
      shuffle_files: unused — HF `datasets` does not support file-level
        shuffling.
      dataset_name: HF dataset identifier.
      subset_name: HF subset/config name (may be None).
      split_mapping: dict from canonical split name to HF split name.
      seed: unused, see `shuffle_files`.

    Returns:
      A `tf.data.Dataset` yielding the split's examples.
    """
    del shuffle_files, seed  # HF datasets does not support file-level shuffling.
    # Removed leftover debug prints that logged misleading progress messages.
    dataset = datasets.load_dataset(dataset_name, subset_name)
    dataset = dataset[split_mapping[split]]
    # TODO: apply the promptsource template once that wiring lands.
    # dataset = utils.apply_template(dataset, template)
    return hf_dataset_to_tf_dataset(dataset)
|
||
# Runs of characters not allowed in a SeqIO task name.
_DISALLOWED_TASK_NAME_CHARS = re.compile(r"[^\w\d\._]+")


def task_clean(text):
    """Collapse each run of disallowed characters into a single underscore."""
    return _DISALLOWED_TASK_NAME_CHARS.sub("_", text)
|
||
def get_task_name(dataset_name, subset_name):
    """Build a sanitized task name ending in "_" from dataset and subset."""
    if subset_name is None:
        raw_name = dataset_name + "_"
    else:
        raw_name = f"{dataset_name}_{subset_name}_"
    return task_clean(raw_name)
|
||
# ========================= XP3 Training Sets ===========================
# Register one TaskConfig per xP3 training task listed in constants_xp3.
for task_name in constants_xp3.XP3_TRAIN_TASKS_SPLIT:
    # NOTE(review): fixed to reference `XP3_TRAIN_TASKS_SPLIT`, the attribute
    # actually defined in constants_xp3 (previous name `XP3_TRAIN_TASK_SPLITS`
    # would raise AttributeError at import time).
    subtask_id = task_name.split(":")[-1]
    # Task ids look like "glue_mrpc_equivalent": dataset, subset, template.
    id_parts = subtask_id.split("_")
    ds_name = id_parts[0]
    subset_name = id_parts[1]
    # Skip sets already covered by FLAN.
    # NOTE(review): xP3-only task names may be absent from the T0 metadata
    # table — confirm this lookup cannot KeyError before relying on it.
    if constants_t0.T0_TRAIN_TASK_METADATA[task_name]["in_flan"]:
        continue
    # TODO: port per-task preprocessor/postprocessor selection from the T0
    # pipeline (score_eval skipping, question-answer first-word/first-line
    # postprocessors, separated-options formatting). For now every task uses
    # the multiple-choice t0 preprocessor and no postprocessor.
    preprocessors = [functools.partial(prep.t0, multiple_choice=True)]
    postprocessors = None

    # Attach provenance info to every example.
    # NOTE(review): task_source is "P3" here — confirm whether it should be
    # "xP3" for this mixture.
    t0_metadata_prep = functools.partial(prep.add_source_info,
                                         task_name=subtask_id, task_source="P3")
    XP3_TASK_CONFIGS[task_name] = TaskConfig(
        source=seqio.TfdsDataSource(
            # TODO: bring over xP3's custom dataset-version picking; until
            # then rely on the community-mirrored HuggingFace TFDS catalog.
            tfds_name=f"huggingface:{ds_name}/{subset_name}",
            splits=["train"]),
        preprocessors=preprocessors + [t0_metadata_prep],
        postprocess_fn=postprocessors,
        metric_fns=[t5_metrics.accuracy],
    )


# =========== Freeze task configs ========== #
XP3_TASK_CONFIGS = frozendict.frozendict(XP3_TASK_CONFIGS)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
import json | ||
import csv | ||
# pip install -q datasets | ||
import datasets | ||
# git clone -b tr13 https://github.com/Muennighoff/promptsource.git && cd promptsource; pip install -e . | ||
from promptsource.templates import DatasetTemplates | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd be happy to get rid of using promptsource here. It's just used to get all the possible prompt templates for a given dataset. I'm assuming FC has that builtin already, but wasn't sure where to look. |
||
|
||
|
||
# Set to False to use multilingual prompts e.g. 'id' for xcopa/id instead of 'en'
USE_ENGLISH_PROMPTS = True

# Some datasets have test sets with hidden labels which will still compile but
# only to noise e.g. piqa test labels are all [-1] which still works on list
# indices resulting in noise samples where the label is always the same.
# Maps dataset name -> {split name: list of template names to skip, or ["all"]}.
SKIP_PROMPTS = {
    "common_gen": {"test": ["all"]},
    "piqa": {"test": ["all"]},
    # NOTE(review): "qasc" was listed twice with identical values; the
    # duplicate key silently overwrote the first — kept a single entry.
    "qasc": {"test": ["all"]},
    "imdb": {"unsupervised": ["all"]},
    "glue/qqp": {"test": ["all"]},
    "cosmos_qa": {"test": [
        "description_context_question_answer_text",
        "description_context_question_text",
        "description_context_question_answer_id",
        "context_answer_to_question",
        "context_description_question_answer_text",
        "context_description_question_answer_id",
        "context_question_description_answer_id",
        "context_description_question_text",
        "context_question_description_answer_text",
        "only_question_answer",
        "no_prompt_id",
        "context_question_description_text",
        "no_prompt_text",
    ]},
    "clue/tnews": {"test": ["all"]},
    "clue/csl": {"test": ["all"]},
    "clue/cmrc2018": {"test": ["generate_question", "in_an_exam", "answer_in_the_passage", "answer_following_question", "xp3longcontinue"]},
    "clue/drcd": {"test": ["generate_question", "in_an_exam", "answer_in_the_passage", "answer_following_question", "xp3longcontinue"]},
    "hellaswag": {"test": ["complete_first_then", "Topic of the context", "Open-ended completion", "Randomized prompts template", "Appropriate continuation - Yes or No", "Predict ending with hint", "Open-ended start", "Reversed appropriate continuation - Yes or No", "how_ends", "if_begins_how_continues"]},
}

# For multilingual datasets, the subset/config whose prompts are in English
# (used when USE_ENGLISH_PROMPTS is True).
DS_TO_ENG_PROMPT = {
    "xcopa": "en",
    "Muennighoff/xstory_cloze": "en",
    "Muennighoff/xwinograd": "en",
    'GEM/wiki_lingua': 'en_en',  # Contains correct language names
    'xnli': 'en',
    "paws-x": "en",
    "mlqa": "mlqa.en.en",
    "xquad": "xquad.en",
    "khalidalt/tydiqa-primary": "english",
    "khalidalt/tydiqa-goldp": "english",
    "pasinit/xlwic": "en",
    "GEM/xlsum": "english",
    "GEM/BiSECT": "en",
}
|
||
def get_dataset_splits(dataset_name, subset_name=None):
    """Return the split metadata for one subset of a HuggingFace dataset.

    Falls back to the first available subset when none is given.
    """
    # get_dataset_infos returns per-subset metadata keyed by subset name.
    infos = datasets.get_dataset_infos(dataset_name)
    if not subset_name:
        subset_name = next(iter(infos))
    return infos[subset_name].splits
|
||
def get_num_examples(dataset_splits):
    """Map each split name to that split's example count."""
    return {name: info.num_examples for name, info in dataset_splits.items()}
|
||
def get_tasks_splits(ds):
    """Return one "'xp3:...':{split counts}" string per prompt template.

    `ds` is a (dataset_name, subset_name) pair as read from the CSV.
    """
    ds_name, subset_name = ds

    ### GET DATASET & LANGUAGE ###
    dataset_splits = get_dataset_splits(ds_name, subset_name)
    if subset_name == "xlwic_en_zh":
        # Train set is en; val & test are zh
        del dataset_splits["train"]
    elif ds_name == "teven/code_docstring_corpus":
        # Bad quality split
        del dataset_splits["class_level"]

    ### SELECT PROMPTS ###
    if subset_name is None:
        prompt_dataset_name = ds_name
    else:
        if USE_ENGLISH_PROMPTS and ds_name in DS_TO_ENG_PROMPT:
            subset_for_prompt = DS_TO_ENG_PROMPT[ds_name]
        else:
            subset_for_prompt = subset_name
        prompt_dataset_name = f"{ds_name}/{subset_for_prompt}"

    prompts = DatasetTemplates(prompt_dataset_name)

    ### PROCESS ###
    # Split counts do not depend on the template, so compute them once.
    split_counts = get_num_examples(dataset_splits)
    splits = []
    for template_name in prompts.all_template_names:
        quoted_name = f"'xp3:{ds_name}_{subset_name}_{template_name}'"
        quoted_name = quoted_name.replace("/", "_").replace(" ", "_")
        splits.append(quoted_name + f":{split_counts}")

    return splits
|
||
|
||
# Driver: compute split counts for every dataset listed in the CSV and write
# them out as lines ready to paste into an xP3 constants file.
TRAIN_SPLITS = []
# newline="" is the documented way to open files for the csv module.
with open("xp3_train_datasets.csv", "r", newline="") as f:
    reader = csv.reader(f)
    next(reader)  # skip header row
    for ds in reader:
        # extend() instead of repeated list concatenation, which copies the
        # accumulated list on every iteration (quadratic).
        TRAIN_SPLITS.extend(get_tasks_splits(ds))

with open("xp3_train_splits.txt", "w") as f:
    for split_line in TRAIN_SPLITS:
        # One write per line; trailing comma keeps the output paste-ready.
        f.write(split_line + ",\n")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is incomplete, I have the script running but it's slow