Add rejection sampling analysis #253

Draft: wants to merge 53 commits into base: main
53 commits
9035816: move rejection sampling within its own directory (nouhadziri, Aug 12, 2024)
e97eb4e: move rejection sampling within its own directory (nouhadziri, Aug 12, 2024)
4bfc79b: add api-based models to generating completions (nouhadziri, Aug 12, 2024)
7a0abc4: debug (nouhadziri, Aug 12, 2024)
034259d: debug (nouhadziri, Aug 12, 2024)
b5990ef: debug (nouhadziri, Aug 12, 2024)
d44e1cb: debug (nouhadziri, Aug 12, 2024)
1c99e32: debug (nouhadziri, Aug 13, 2024)
6013160: do not apply tokenization with chat template (nouhadziri, Aug 13, 2024)
c47a16e: change the tokenizer (nouhadziri, Aug 13, 2024)
91fd8b8: change the tokenizer (nouhadziri, Aug 13, 2024)
7093b10: fix template (nouhadziri, Aug 13, 2024)
805fcac: fix arg issue (nouhadziri, Aug 13, 2024)
f93e738: add mode to arg (nouhadziri, Aug 13, 2024)
77d257b: add mode to arg (nouhadziri, Aug 13, 2024)
b2a4b83: change the skill (nouhadziri, Aug 13, 2024)
b80894a: remove main from api_generate.py (nouhadziri, Aug 13, 2024)
fca6f6d: modify template chat (nouhadziri, Aug 13, 2024)
6541c9f: modify template chat (nouhadziri, Aug 13, 2024)
1630b4b: return only text (nouhadziri, Aug 13, 2024)
ebacf9e: return only text (nouhadziri, Aug 13, 2024)
45d14b3: return only text (nouhadziri, Aug 13, 2024)
60a7a87: fix dict (nouhadziri, Aug 13, 2024)
33e8f8c: fix dict (nouhadziri, Aug 13, 2024)
9d99744: fix dict (nouhadziri, Aug 13, 2024)
7ce7419: fix dict (nouhadziri, Aug 13, 2024)
f3f1466: fix dict (nouhadziri, Aug 13, 2024)
97b812e: Merge branch 'main' into llm-as-judge (nouhadziri, Aug 13, 2024)
7fe90fd: fix dict (nouhadziri, Aug 13, 2024)
fe0cfe3: add LM as a scorer (nouhadziri, Aug 13, 2024)
2ac539a: add LM as a scorer (nouhadziri, Aug 13, 2024)
b54dc5e: add LM as a scorer (nouhadziri, Aug 13, 2024)
a0ae60b: add LM as a scorer (nouhadziri, Aug 13, 2024)
f7b9188: change the name of the model (nouhadziri, Aug 13, 2024)
28dfddf: fix rank bug (nouhadziri, Aug 13, 2024)
e50ef23: process score (nouhadziri, Aug 13, 2024)
4ffdf75: process score (nouhadziri, Aug 13, 2024)
a352b69: edit the template (nouhadziri, Aug 13, 2024)
4e89a5e: extract score (nouhadziri, Aug 13, 2024)
75e7f66: compute scores for reference responses (nouhadziri, Aug 13, 2024)
e3d6b07: compute scores for reference responses (nouhadziri, Aug 13, 2024)
4abbf4f: add tqdm (nouhadziri, Aug 13, 2024)
d2ce21a: add tqdm (nouhadziri, Aug 13, 2024)
089726f: fix score bug (nouhadziri, Aug 13, 2024)
fe0533d: fix item bug (nouhadziri, Aug 13, 2024)
95213b3: fix item bug (nouhadziri, Aug 13, 2024)
4f7acc0: fixed code style (nouhadziri, Aug 13, 2024)
e436c31: fixed code style (nouhadziri, Aug 13, 2024)
93c54e0: .idea added to .gitignore (nouhadziri, Aug 13, 2024)
d5b43ee: .idea dir removed (nouhadziri, Aug 13, 2024)
107ac32: remove unused llm_as_judge file (nouhadziri, Aug 13, 2024)
dacf4b3: Update rejection_sampling.md (nouhadziri, Aug 13, 2024)
ab96ef5: Add analysis script (vwxyzjn, Aug 13, 2024)
2 changes: 2 additions & 0 deletions .gitignore
@@ -135,3 +135,5 @@ dmypy.json

# Pyre type checker
.pyre/

.idea/
6 changes: 5 additions & 1 deletion docs/algorithms/rejection_sampling.md
@@ -7,6 +7,7 @@ different number of completions per prompt.

# Debug run (use an interactive session)

This code supports HF models, local models, and API-based models (e.g., `gpt-4`). For generating completions, the code currently accepts one model at a time; support for an ensemble of models is planned. Stay tuned.
```bash
## tulu v3 recipe
# 1. first sample a bunch of completions given prompts
@@ -16,7 +17,10 @@ python open_instruct/generation.py \
--n 3 \
--save_filename output/completions.jsonl \
--sanity_check \

```
### Scoring completions
You can score responses with either a single RM or a list of RMs. With multiple RMs, the final score is decided by a majority vote across them. The RMs can be models explicitly trained as reward models, HF LMs, or API-based models.
```
# 2. tokenize them and run a reward model to filter them
python open_instruct/rejection_sampling.py \
--input_filename output/completions.jsonl \
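The scoring note above mentions taking a majority vote across multiple RMs, but the aggregation itself is not shown in this diff. Below is a minimal, hypothetical sketch of one way such a vote could work; the name `majority_vote_best` and its inputs are illustrative, not from the codebase.

```python
# Hypothetical sketch (not part of this PR): majority-vote aggregation across
# multiple reward models. Each RM votes for the completion it scores highest;
# the completion with the most votes is selected.
from collections import Counter
from typing import List


def majority_vote_best(scores_per_rm: List[List[float]]) -> int:
    """scores_per_rm[i][j] is RM i's score for completion j.

    Returns the index of the completion that most RMs rank first.
    """
    votes = [max(range(len(scores)), key=scores.__getitem__) for scores in scores_per_rm]
    return Counter(votes).most_common(1)[0][0]


if __name__ == "__main__":
    # Two of three RMs prefer completion 1, so index 1 is selected.
    print(majority_vote_best([[0.1, 0.9], [0.4, 0.6], [0.8, 0.2]]))  # -> 1
```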
2 changes: 1 addition & 1 deletion open_instruct/model_utils.py
@@ -29,13 +29,13 @@
import transformers
from accelerate import Accelerator
from accelerate.state import AcceleratorState
from huggingface_hub import HfApi
from rich import print as rprint
from rich.console import Console
from rich.table import Table
from rich.text import Text
from torch.nn.parallel.distributed import DistributedDataParallel
from transformers import PreTrainedModel, PreTrainedTokenizer
from huggingface_hub import HfApi


@dataclass
Empty file.
93 changes: 93 additions & 0 deletions open_instruct/rejection_sampling/analyze.py
@@ -0,0 +1,93 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List

import numpy as np
from datasets import load_dataset
from rich.columns import Columns
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from transformers import HfArgumentParser


def maybe_markdown(content: str, markdown: bool) -> str:
    return Markdown(content) if markdown else content


def print_hf_chosen_rejected(
    chosen: List[Dict[str, str]],
    rejected: List[Dict[str, str]],
    markdown: bool = False,
    chosen_key: str = "chosen",
    rejected_key: str = "rejected",
):
    """Here we assume chosen[:-1] == rejected[:-1]."""
    assert len(chosen) == len(rejected)
    assert chosen[:-1] == rejected[:-1]
    console = Console()
    colors = ["red", "green"]
    color_idx = 0
    console.rule(f"[bold yellow]The number of turns is {len(chosen)}")
    # Print the shared conversation history, alternating panel colors per turn.
    for i in range(len(chosen) - 1):
        message = chosen[i]
        role = message["role"]
        content = maybe_markdown(message["content"], markdown)
        console.print(Panel(content, title_align="left", title=role, border_style=colors[color_idx]))
        color_idx = (color_idx + 1) % 2
    # Show the final chosen and rejected turns side by side.
    half_width = int(0.48 * console.width)
    columns = Columns(
        [
            Panel(maybe_markdown(chosen[-1]["content"], markdown), width=half_width, title=chosen_key, border_style="green"),
            Panel(maybe_markdown(rejected[-1]["content"], markdown), width=half_width, title=rejected_key, border_style="red"),
        ],
    )
    console.print(Panel(columns, title=chosen[-1]["role"], border_style=colors[color_idx]))


@dataclass
class Args:
    rejection_sampled_dataset: str = "vwxyzjn/rejection_sampling_31313"
    shuffle: bool = False


def main(args: Args):
    ds = load_dataset(args.rejection_sampled_dataset, split="train")
    if args.shuffle:
        ds = ds.shuffle()

    print("🚀 Dataset loaded, starting to analyze...")
    chosen_scores = defaultdict(list)
    rejected_scores = defaultdict(list)
    reference_scores = defaultdict(list)
    chosen_lengths = []
    rejected_lengths = []
    reference_lengths = []
    for example in ds:
        chosen_lengths.append(len(example["chosen"][-1]["content"]))
        rejected_lengths.append(len(example["rejected"][-1]["content"]))
        reference_lengths.append(len(example["reference_completion"]))
        for key in example["chosen_score"]:
            chosen_scores[key].append(example["chosen_score"][key])
            rejected_scores[key].append(example["rejected_score"][key])
            reference_scores[key].append(example["reference_completion_score"][key])

    print(f"chosen: mean length = {np.mean(chosen_lengths)}")
    print(f"rejected: mean length = {np.mean(rejected_lengths)}")
    print(f"reference: mean length = {np.mean(reference_lengths)}")
    for key in example["chosen_score"]:
        print(f"{key=}")
        print(f"chosen: mean score = {np.mean(chosen_scores[key])}")
        print(f"        std score = {np.std(chosen_scores[key])}")
        print(f"rejected: mean score = {np.mean(rejected_scores[key])}")
        print(f"          std score = {np.std(rejected_scores[key])}")
        print(f"reference: mean score = {np.mean(reference_scores[key])}")
        print(f"           std score = {np.std(reference_scores[key])}")

    # Walk through cases where the reference completion out-scores the chosen completion.
    for i in range(len(chosen_scores[key])):
        if reference_scores[key][i] > chosen_scores[key][i]:
            print("reference is better than chosen")
            print_hf_chosen_rejected(
                ds[i]["chosen"][:-1] + [{"role": "assistant", "content": ds[i]["reference_completion"]}],
                ds[i]["chosen"],
                chosen_key="reference",
                rejected_key="chosen",
            )
            input("Press Enter to continue...")


if __name__ == "__main__":
    main(*HfArgumentParser(Args).parse_args_into_dataclasses())
75 changes: 75 additions & 0 deletions open_instruct/rejection_sampling/api_generate.py
@@ -0,0 +1,75 @@
# api_generate.py

import asyncio
import re
from dataclasses import dataclass
from typing import List, Optional

from openai import AsyncOpenAI
from prompt_templates import get_generation_template, get_judgment_template
from tqdm.asyncio import tqdm


@dataclass
class LLMGenerationConfig:
    n: int = 64
    model: str = "gpt-3.5-turbo-0125"
    max_parallel_requests: Optional[int] = None

    def __post_init__(self):
        if "gpt-3.5" in self.model:
            self.max_parallel_requests = 11
        elif "gpt-4" in self.model:
            self.max_parallel_requests = 13
        elif self.max_parallel_requests is None:
            # Fallback so the semaphore in process_batch always receives an int.
            self.max_parallel_requests = 8


@dataclass
class Args:
    output_path: Optional[str] = None
    num_trials: int = 1
    skill: str = "summarization"
    mode: str = "generation"  # Can be "generation" or "judgment"


class LLMProcessor:
    def __init__(self, config: LLMGenerationConfig):
        self.config = config
        self.async_client = AsyncOpenAI()

    async def process_text(self, data: dict, i: int, limiter: asyncio.Semaphore, args: Args):
        if args.mode == "generation":
            template = get_generation_template(args.skill)
            text = template.format(prompt=data)
        else:  # judgment mode
            template = get_judgment_template(args.skill)
            text = template.format(prompt=data["prompt"], response=data["response"])

        async with limiter:
            while True:
                try:
                    response = await self.async_client.chat.completions.create(
                        model=self.config.model,
                        messages=[
                            {"role": "system", "content": "You are a helpful assistant."},
                            {"role": "user", "content": text},
                        ],
                    )
                    response = response.choices[0].message.content
                    if args.mode == "judgment":
                        # Parse the numeric judgment; fall back to -1 if the score is missing.
                        match = re.search(r"Total score:\s*(\d+)", response)
                        response = int(match.group(1)) if match else -1
                    break
                except Exception as e:
                    print(f"Error in {i}: {e}")
                    await asyncio.sleep(30)

        return response

    async def process_batch(self, data_list: List[dict], args: Args):
        limiter = asyncio.Semaphore(self.config.max_parallel_requests)
        tasks = [self.process_text(data, i, limiter, args) for i, data in enumerate(data_list)]
        # Use tqdm to track progress
        return await tqdm.gather(*tasks, total=len(tasks), desc="Processing Batch")
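For context, here is a small usage sketch of the `LLMProcessor` added above, mirroring how `generate_with_openai` in the generation script drives it. It assumes an `OPENAI_API_KEY` is set in the environment and that `prompt_templates` provides a generation template for the `chat` skill; the prompt string is a placeholder.

```python
# Usage sketch only (assumes OPENAI_API_KEY is set and api_generate.py is importable).
import asyncio

from api_generate import Args, LLMGenerationConfig, LLMProcessor


async def run_generation(prompts):
    config = LLMGenerationConfig(model="gpt-4", n=1)
    processor = LLMProcessor(config)
    # mode="generation" formats each prompt with the skill's generation template.
    args = Args(mode="generation", skill="chat")
    return await processor.process_batch(prompts, args)


if __name__ == "__main__":
    completions = asyncio.run(run_generation(["Write a haiku about rejection sampling."]))
    print(completions)
```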
@@ -12,15 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import asyncio
import copy
import json
import multiprocessing
import os
from collections import defaultdict
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from typing import Optional

import pandas as pd
from api_generate import LLMGenerationConfig, LLMProcessor # Import your classes
from datasets import load_dataset
from rich.console import Console
from rich.pretty import pprint
@@ -33,12 +36,13 @@
class Args:
model_name_or_path: str = "cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr"
save_filename: str = "completions.jsonl"
mode: str = "generation"
skill: str = "chat"


@dataclass
class GenerationArgs:
n: int = 1
"""the number of samples to generate per prompt"""
temperature: float = 0.8
response_length: int = 53
tensor_parallel_size: int = 1
@@ -66,27 +70,15 @@ def print_rich_table(df: pd.DataFrame) -> Table:
console.print(table)


def main(args: Args, dataset_args: DatasetArgs, gen_args: GenerationArgs):
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
ds = load_dataset(dataset_args.dataset_name)
if dataset_args.sanity_check:
for key in ds:
ds[key] = ds[key].select(range(min(dataset_args.sanity_check_size, len(ds[key]))))
if dataset_args.dataset_end_idx is None:
dataset_args.dataset_end_idx = len(ds[dataset_args.dataset_train_split])
for key in ds:
ds[key] = ds[key].select(range(dataset_args.dataset_start_idx, dataset_args.dataset_end_idx))
pprint([dataset_args, args, gen_args])
async def generate_with_openai(model_name: str, data_list: list, args: Args, n: int):
config = LLMGenerationConfig(model=model_name, n=n)
processor = LLMProcessor(config)
results = await processor.process_batch(data_list, args)
return results

# DATASET specific logic: in this dataset the prompt is simply just a list of strings
ds = ds.map(
lambda x: {"prompt_token_ids": tokenizer.apply_chat_template(x["messages"][:-1])},
num_proc=multiprocessing.cpu_count(),
)
prompt_token_ids = ds[dataset_args.dataset_train_split]["prompt_token_ids"]

# Generate using vLLM
llm = LLM(model=args.model_name_or_path, tensor_parallel_size=gen_args.tensor_parallel_size)
def generate_with_vllm(model_name_or_path: str, prompt_token_ids, gen_args: GenerationArgs):
llm = LLM(model=model_name_or_path, tensor_parallel_size=gen_args.tensor_parallel_size)
outputs = llm.generate(
prompt_token_ids=prompt_token_ids,
sampling_params=SamplingParams(
@@ -97,6 +89,56 @@ def main(args: Args, dataset_args: DatasetArgs, gen_args: GenerationArgs):
include_stop_str_in_output=True,
),
)

return [
{
"outputs": [asdict(out) for out in output.outputs],
"prompt": output.prompt,
"prompt_logprobs": output.prompt_logprobs,
"metrics": output.metrics,
}
for output in outputs
]


def main(args: Args, dataset_args: DatasetArgs, gen_args: GenerationArgs):

ds = load_dataset(dataset_args.dataset_name)
if dataset_args.sanity_check:
for key in ds:
ds[key] = ds[key].select(range(min(dataset_args.sanity_check_size, len(ds[key]))))
if dataset_args.dataset_end_idx is None:
dataset_args.dataset_end_idx = len(ds[dataset_args.dataset_train_split])
for key in ds:
ds[key] = ds[key].select(range(dataset_args.dataset_start_idx, dataset_args.dataset_end_idx))
pprint([dataset_args, args, gen_args])

if "gpt-3.5" in args.model_name_or_path or "gpt-4" in args.model_name_or_path:
use_openai = True
else:
use_openai = False

if not use_openai:
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

ds = ds.map(
lambda x: {"prompt_token_ids": tokenizer.apply_chat_template(x["messages"][:-1])},
num_proc=multiprocessing.cpu_count(),
)
prompt_token_ids = ds[dataset_args.dataset_train_split]["prompt_token_ids"]
# Generate using vLLM.
outputs = generate_with_vllm(args.model_name_or_path, prompt_token_ids, gen_args)

else:
tokenizer = AutoTokenizer.from_pretrained("allenai/llama-3-tulu-2-8b")
ds = ds.map(
lambda x: {"prompt": tokenizer.apply_chat_template(x["messages"][:-1], tokenize=False)},
num_proc=multiprocessing.cpu_count(),
)
messages = ds[dataset_args.dataset_train_split]["prompt"]
responses = asyncio.run(generate_with_openai(args.model_name_or_path, messages, args, gen_args.n))
outputs = [{"outputs": [{"text": response}]} for response in responses]

# Assuming we generate n=3 completions per prompt, the outputs will look like:
# prompt | completions
# -------|------------
@@ -107,19 +149,13 @@ def main(args: Args, dataset_args: DatasetArgs, gen_args: GenerationArgs):
# ...
table = defaultdict(list)
for output, messages in zip(outputs, ds[dataset_args.dataset_train_split]["messages"]):
# if the model completions are exactly the same across all completions, we can skip this
if len(set([item.text for item in output.outputs])) == 1:
continue

for item in output.outputs:
for item in output["outputs"]:
new_messages = copy.deepcopy(messages[:-1])
new_messages.append({"role": "assistant", "content": item.text})
new_messages.append({"role": "assistant", "content": item["text"]})
table["messages"].append(new_messages)
table["model_completion"].append(item.text)
table["model_completion"].append(item["text"])
table["reference_completion"].append(messages[-1]["content"])

# print_rich_table(pd.DataFrame(table)) # uncomment this line to print the table

# Save results
os.makedirs(os.path.dirname(args.save_filename), exist_ok=True)
with open(args.save_filename, "w") as outfile:
@@ -129,6 +165,7 @@ def main(args: Args, dataset_args: DatasetArgs, gen_args: GenerationArgs):
"messages": table["messages"][i],
"model_completion": table["model_completion"][i],
"reference_completion": table["reference_completion"][i],
"model": args.model_name_or_path,
},
outfile,
)
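For reference, each line written to `--save_filename` after this PR carries the keys passed to `json.dump` above (`messages`, `model_completion`, `reference_completion`, and the new `model` field). A purely illustrative record, shown as a Python literal:

```python
# Illustrative only: the shape of one JSONL record written by the generation script after this PR.
example_record = {
    "messages": [
        {"role": "user", "content": "Summarize the article."},
        {"role": "assistant", "content": "The article argues that ..."},  # model completion appended
    ],
    "model_completion": "The article argues that ...",
    "reference_completion": "A human-written summary ...",
    "model": "gpt-4",  # args.model_name_or_path, the field added by this PR
}
```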