Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keep target_for_fewshot_sorting for another purpose #241

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 4 additions & 19 deletions src/lighteval/tasks/default_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ def bbh_harness(line, task_name: str = None):
query=query,
choices=choices,
gold_index=correct_index,
target_for_fewshot_sorting=choices,
instruction=line.get("task_prefix", None),
)

Expand All @@ -196,7 +195,6 @@ def bbh_lighteval(line, task_name: str = None):
query=query,
choices=LETTER_INDICES[: len(line["choices"])],
gold_index=line["target_idx"],
target_for_fewshot_sorting=LETTER_INDICES[: len(line["choices"])],
instruction=line.get("task_prefix", None),
)

Expand All @@ -205,9 +203,8 @@ def bbh(line, instruction, choices, task_name: str = None):
return Doc(
task_name=task_name,
query=f"{instruction}Q: {line['input']}\nA:",
choices=choices,
choices=[(" " if line["__few_shots"] else "") + c for c in choices],
gold_index=choices.index(line["target"]),
target_for_fewshot_sorting=[f" {c}" for c in choices],
instruction=instruction,
)

Expand Down Expand Up @@ -793,10 +790,9 @@ def hellaswag_helm(line, task_name: str = None):
return Doc(
task_name=task_name,
query=query,
choices=[" " + i for i in LETTER_INDICES[: len(line["endings"])]],
choices=[" " + i for i in LETTER_INDICES[: len(line["endings"])]] + ([""] if line["__few_shot"] else []),
gold_index=gold_ix, # -1 for test,
instruction="The following are multiple choice questions (with answers) about common sense.\n\n",
target_for_fewshot_sorting=line["endings"][gold_ix] if gold_ix > -1 else "",
specific={
"label_to_choices": {f" {key}": choice for key, choice in zip(LETTER_INDICES, line["endings"])},
},
Expand Down Expand Up @@ -1352,7 +1348,6 @@ def mmlu(line, topic, task_name: str = None):
choices=[" A", " B", " C", " D"] if is_few_shots else ["A", "B", "C", "D"],
gold_index=gold_ix,
instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n",
target_for_fewshot_sorting=[" A", " B", " C", " D"][gold_ix],
)


Expand All @@ -1373,7 +1368,6 @@ def custom_mmlu_thom(line, task_name: str = None):
choices=[" A", " B", " C", " D"] if is_few_shots else ["A", "B", "C", "D"],
gold_index=gold_ix,
instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n",
target_for_fewshot_sorting=[" A", " B", " C", " D"][gold_ix],
)


Expand Down Expand Up @@ -1613,15 +1607,13 @@ def mmlu_harness(line, task_name: str = None):
query += "Answer:"

gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"]
"__few_shots" in line and line["__few_shots"] is True # We are adding few shots

return Doc(
task_name=task_name,
query=query,
choices=[" A", " B", " C", " D"],
gold_index=gold_ix,
instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n",
target_for_fewshot_sorting=[" A", " B", " C", " D"][gold_ix],
)


Expand All @@ -1632,14 +1624,14 @@ def mmlu_helm(line, task_name: str = None):
query += "\nAnswer:"

gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"]
is_few_shots = line.get("__few_shots", False) # We are adding few shots

return Doc(
task_name=task_name,
query=query,
choices=[" A", " B", " C", " D"],
choices=[" A", " B", " C", " D"] if not is_few_shots else ["A", "B", "C", "D"], # specific to HELM evals
gold_index=gold_ix,
instruction=f"The following are multiple choice questions (with answers) about {subject.replace('_', ' ')}.\n\n",
target_for_fewshot_sorting=line["choices"][gold_ix], # specific to HELM evals
)


Expand Down Expand Up @@ -1804,7 +1796,6 @@ def openbookqa_helm(line, task_name: str = None):
choices=["A", "B", "C", "D", "E"],
gold_index=gold_ix,
instruction="The following are multiple choice questions (with answers) about common sense.\n",
target_for_fewshot_sorting=line["choices"]["text"][gold_ix], # specific to HELM evals
)


Expand All @@ -1825,14 +1816,12 @@ def piqa_helm(line, task_name: str = None):
query += "Answer: "

gold_ix = int(line["label"])

return Doc(
task_name=task_name,
query=query,
choices=["A", "B"],
gold_index=gold_ix,
instruction="The following are multiple choice questions (with answers) about common sense.\n",
target_for_fewshot_sorting=[line["sol1"], line["sol2"]][gold_ix],
)


Expand Down Expand Up @@ -1865,13 +1854,11 @@ def pubmed_qa_helm(line, task_name: str = None):
)
query += f"\n\nQuestion: {line['question']}\nAnswer: "
gold_ix = ["yes", "no", "maybe"].index(line["final_decision"])

return Doc(
task_name=task_name,
query=query,
choices=["A", "B", "C"],
gold_index=gold_ix,
target_for_fewshot_sorting=["yes", "no", "maybe"][gold_ix],
)


Expand Down Expand Up @@ -2251,13 +2238,11 @@ def truthful_qa_helm(line, task_name: str = None):
query = f"Question: {line['question']}\n"
query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])])
query += "Answer:"

return Doc(
task_name=task_name,
query=query,
choices=LETTER_INDICES[: len(line["choices"])],
gold_index=line["gold_index"],
target_for_fewshot_sorting=line["choices"][line["gold_index"]],
)


Expand Down
6 changes: 2 additions & 4 deletions src/lighteval/tasks/lighteval_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,20 +340,18 @@ def eval_docs(self) -> list[Doc]:
self._docs = self.remove_duplicate_docs(self._docs)
return self._docs

def doc_to_target(self, formatted_doc: Doc, few_shot: bool = False) -> str:
def doc_to_target(self, formatted_doc: Doc) -> str:
"""
Returns the target of the given document.

Args:
formatted_doc (Doc): Formatted document.
few_shot (bool, optional): Whether the document is used for few
shot examples. Defaults to False.

Returns:
str: Target of the document, which is the correct answer for a document.
"""
# likely we mostly need one example not all
return as_list(formatted_doc.get_golds(few_shot=few_shot))[0]
return as_list(formatted_doc.get_golds())[0]

def construct_requests(
self, formatted_doc: Doc, context: str, document_id_seed: str, current_task_name: str
Expand Down
12 changes: 4 additions & 8 deletions src/lighteval/tasks/prompt_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,18 @@ def doc_to_text(doc: Doc, return_instructions: bool = False) -> Union[str, Tuple
)

@staticmethod
def doc_to_target(formatted_doc: Doc, few_shot: bool = False) -> str:
def doc_to_target(formatted_doc: Doc) -> str:
"""
Returns the target of the given document.

Args:
formatted_doc (Doc): Formatted document.
few_shot (bool, optional): Whether the document is used for few
shot examples. Defaults to False.

Returns:
str: Target of the document, which is the correct answer for a document.
"""
# likely we mostly need one example not all
return as_list(formatted_doc.get_golds(few_shot=few_shot))[0]
return as_list(formatted_doc.get_golds())[0]

def add_context_to_doc(
self,
Expand Down Expand Up @@ -255,9 +253,7 @@ def get_examples(
class FewShotSelectionMethod:
sorting: str # sorting method for the overall few shot pool (balanced, random, sequential)
with_sampling: bool # samples item randomly from the few shot pool
fewshotpool_unique: (
bool
) # set to true if you are CERTAIN there is no intersection between the few shot pool and your evaluation set
fewshotpool_unique: bool # set to true if you are CERTAIN there is no intersection between the few shot pool and your evaluation set


class FewShotSelection(Enum):
Expand Down Expand Up @@ -363,7 +359,7 @@ def _init_fewshot_sampling_balanced(
# Sort by counts of labels
label_to_instances = defaultdict(list)
for instance in fewshotpool:
target = PromptManager.doc_to_target(instance, few_shot=True)
target = instance.get_target_for_fewshot_sorting()
label_to_instances[target].append(instance)

counts_to_labels = defaultdict(list)
Expand Down
15 changes: 6 additions & 9 deletions src/lighteval/tasks/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ class Doc:

# For few-shot
instruction: Optional[str] = ""
target_for_fewshot_sorting: Optional[str] = None # will probably have to be removed in the future
target_for_fewshot_sorting: Optional[str] = None

# Filled when parsing and adding the few-shot context
ctx: Optional[str] = ""
Expand All @@ -193,20 +193,17 @@ def __post_init__(self):
if self.instruction is None:
self.instruction = ""

def get_golds(self, few_shot: bool = False):
def get_golds(self):
"""Return gold targets extracted from the target dict"""
gold_indices = as_list(self.gold_index)
if few_shot and self.target_for_fewshot_sorting is not None:
choices = self.target_for_fewshot_sorting
if isinstance(choices, str): # correct choice is already selected
return choices
else:
choices = self.choices
golds = []
for gold_ix in gold_indices:
golds.extend(as_list(choices[gold_ix]))
golds.extend(as_list(self.choices[gold_ix]))
return golds

def get_target_for_fewshot_sorting(self) -> str:
    """Return the label used to sort/balance this doc in the few-shot pool.

    Prefers the explicitly provided ``target_for_fewshot_sorting``;
    otherwise falls back to the first gold answer of the doc.

    NOTE(review): any falsy explicit target (e.g. the empty string "")
    also triggers the gold-answer fallback — confirm that is intended,
    since some prompt functions used "" as a sentinel.
    """
    explicit_target = self.target_for_fewshot_sorting
    if explicit_target:
        return explicit_target
    return as_list(self.get_golds())[0]

def __repr__(self):
    """Serialize every dataclass field of this doc as a JSON string."""
    return json.dumps(asdict(self))