
Commit 91c66eb

Add 100 Samples Per Regex / JSON Schema

lapp0 committed Oct 13, 2024
1 parent 717b3cc

Showing 11 changed files with 110 additions and 83 deletions.
32 changes: 18 additions & 14 deletions src/benchmark_lfe.py
@@ -9,8 +9,8 @@


 class LMFormatEnforcerRegex:
-    params = [models, regex_cases]
-    param_names = ["model", "regex"]
+    params = [models, regex_cases.keys()]
+    param_names = ["model", "regex_name"]
     timeout = 600

     def setup(self, model, _):
@@ -25,20 +25,22 @@ def setup(self, model, _):
         )
         self.tokenizer_data = build_token_enforcer_tokenizer_data(self.tokenizer)

-    def time_lfe(self, _, regex):
-        regex_string, regex_example = regex["regex"], regex["example"]
-        regex_example_tokens = self.tokenizer.encode(regex_example)
+    def time_lfe(self, _, regex_name):
+        regex_string = regex_cases[regex_name]["regex"]
+        regex_samples = regex_cases[regex_name]["samples"]

         parser = RegexParser(regex_string)
         token_enforcer = TokenEnforcer(self.tokenizer_data, parser)

-        for i in range(len(regex_example_tokens)):
-            _ = token_enforcer.get_allowed_tokens(regex_example_tokens[: i + 1])
+        for regex_sample in regex_samples:
+            regex_sample_tokens = self.tokenizer.encode(regex_sample)
+            for i in range(len(regex_sample_tokens)):
+                _ = token_enforcer.get_allowed_tokens(regex_sample_tokens[: i + 1])


 class LMFormatEnforcerJsonSchema:
-    params = [models, json_cases]
-    param_names = ["model", "json"]
+    params = [models, json_cases.keys()]
+    param_names = ["model", "json_schema_name"]
     timeout = 600

     def setup(self, model, _):
@@ -53,12 +55,14 @@ def setup(self, model, _):
         )
         self.tokenizer_data = build_token_enforcer_tokenizer_data(self.tokenizer)

-    def time_lfe(self, _, json):
-        json_string, json_example = json["schema"], json["example"]
-        json_example_tokens = self.tokenizer.encode(json_example)
+    def time_lfe(self, _, json_schema_name):
+        json_string = json_cases[json_schema_name]["schema"]
+        json_samples = json_cases[json_schema_name]["samples"]

         parser = JsonSchemaParser(json_string)
         token_enforcer = TokenEnforcer(self.tokenizer_data, parser)

-        for i in range(len(json_example_tokens)):
-            _ = token_enforcer.get_allowed_tokens(json_example_tokens[: i + 1])
+        for json_sample in json_samples:
+            json_sample_tokens = self.tokenizer.encode(json_sample)
+            for i in range(len(json_sample_tokens)):
+                _ = token_enforcer.get_allowed_tokens(json_sample_tokens[: i + 1])
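
A note on the pattern above: asv parameterizes each time_* method over the cross product of params and derives benchmark names from the parameters' string form, so passing regex_cases.keys() instead of the case dicts keeps the reported names short; each method then looks its case up by name. The timed loop itself can be read in isolation. A minimal sketch, assuming tokenizer and token_enforcer objects built as in setup() above (walk_samples is a hypothetical helper, not part of the commit):

def walk_samples(token_enforcer, tokenizer, samples):
    """Replay each sample through the enforcer, one token prefix at a time.

    get_allowed_tokens(prefix) returns the token IDs lm-format-enforcer
    permits next; calling it once per prefix is the per-step cost the
    benchmark measures.
    """
    for sample in samples:
        token_ids = tokenizer.encode(sample)
        for i in range(len(token_ids)):
            token_enforcer.get_allowed_tokens(token_ids[: i + 1])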
42 changes: 24 additions & 18 deletions src/benchmark_outlines.py
@@ -11,8 +11,8 @@


 class OutlinesRegex:
-    params = [models, regex_cases]
-    param_names = ["model", "regex"]
+    params = [models, regex_cases.keys()]
+    param_names = ["model", "regex_name"]
     timeout = 1200

     def setup(self, model, _):
@@ -28,7 +28,7 @@ def setup(self, model, _):
         self.tokenizer = TransformerTokenizer(self.tokenizer)
         RegexGuide("a", self.tokenizer)  # JIT-compile and convert the vocabulary

-    def time_outlines(self, _, regex):
+    def time_outlines(self, _, regex_name):
         """Measure generation time with Outlines.

         Outlines' generation time is split between compiling an index for each
@@ -37,19 +37,23 @@ def time_outlines(self, _, regex):
         """
         caching.clear_cache()

-        regex_string, regex_example = regex["regex"], regex["example"]
-        regex_example_tokens = self.tokenizer.encode(regex_example)[0][0]
+        regex_string = regex_cases[regex_name]["regex"]
+        regex_samples = regex_cases[regex_name]["samples"]

         guide = RegexGuide(regex_string, self.tokenizer)

-        state = 0
-        for token in regex_example_tokens:
-            _ = guide.get_next_instruction(state)
-            state = guide.get_next_state(state, token)
+        for regex_sample in regex_samples:
+            regex_sample_tokens = self.tokenizer.encode(regex_sample)[0][0]
+            state = guide.initial_state
+            for token in regex_sample_tokens:
+                _ = guide.get_next_instruction(state)
+                state = guide.get_next_state(state, token)


 class OutlinesJsonSchema:
-    params = [models, json_cases]
-    param_names = ["model", "json"]
+    params = [models, json_cases.keys()]
+    param_names = ["model", "json_schema_name"]
+
     timeout = 1200

     def setup(self, model, _):
@@ -65,20 +69,22 @@ def setup(self, model, _):
         self.tokenizer = TransformerTokenizer(self.tokenizer)
         RegexGuide("a", self.tokenizer)  # JIT-compile and convert the vocabulary

-    def time_outlines(self, _, json_case):
+    def time_outlines(self, _, json_schema_name):
         """Measure generation time with Outlines.

         Outlines' generation time is split between compiling an index for each
         regular expression, and walking this index while generating tokens.
         """
-        json_string, json_example = json_case["schema"], json_case["example"]
-        json_example_tokens = self.tokenizer.encode(json_example)[0][0]
+        json_string = json_cases[json_schema_name]["schema"]
+        json_samples = json_cases[json_schema_name]["samples"]

         regex_string = build_regex_from_schema(json.dumps(json_string))
         guide = RegexGuide(regex_string, self.tokenizer)

-        state = 0
-        for token in json_example_tokens:
-            _ = guide.get_next_instruction(state)
-            state = guide.get_next_state(state, token)
+        for json_sample in json_samples:
+            json_sample_tokens = self.tokenizer.encode(json_sample)[0][0]
+            state = guide.initial_state
+            for token in json_sample_tokens:
+                _ = guide.get_next_instruction(state)
+                state = guide.get_next_state(state, token)
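
The walk time_outlines now performs per sample can likewise be read on its own. A minimal sketch, assuming guide and tokenizer objects built as in the diff above (walk_samples is a hypothetical helper; encode() returning a pair whose first element holds the token-id rows follows its use in the benchmark):

def walk_samples(guide, tokenizer, samples):
    """Replay each sample's tokens through the compiled index.

    At each state, get_next_instruction produces the allowed-token
    instruction (the timed mask lookup) and get_next_state advances
    the index with the token the sample actually contains.
    """
    for sample in samples:
        token_ids = tokenizer.encode(sample)[0][0]  # first row of the ids tensor
        state = guide.initial_state
        for token_id in token_ids:
            guide.get_next_instruction(state)
            state = guide.get_next_state(state, token_id)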
41 changes: 23 additions & 18 deletions src/benchmark_outlines_core.py
@@ -9,8 +9,8 @@


 class OutlinesCoreRegex:
-    params = [models, regex_cases]
-    param_names = ["model", "regex"]
+    params = [models, regex_cases.keys()]
+    param_names = ["model", "regex_name"]
     timeout = 600

     def setup(self, model, _):
@@ -25,26 +25,29 @@ def setup(self, model, _):
         )
         self.tokenizer = TransformerTokenizer(self.tokenizer)

-    def time_outlines_core(self, _, regex):
+    def time_outlines_core(self, _, regex_name):
         """Measure generation time with Outlines.

         Outlines' generation time is split between compiling an index for each
         regular expression, and walking this index while generating tokens.
         """
-        regex_string, regex_example = regex["regex"], regex["example"]
-        regex_example_tokens = self.tokenizer.encode(regex_example)[0][0]
+        regex_string = regex_cases[regex_name]["regex"]
+        regex_samples = regex_cases[regex_name]["samples"]

         guide = RegexGuide(regex_string, self.tokenizer)

-        state = 0
-        for token in regex_example_tokens:
-            _ = guide.get_next_instruction(state)
-            state = guide.get_next_state(state, token)
+        for regex_sample in regex_samples:
+            regex_sample_tokens = self.tokenizer.encode(regex_sample)[0][0]
+            state = guide.initial_state
+            for token in regex_sample_tokens:
+                _ = guide.get_next_instruction(state)
+                state = guide.get_next_state(state, token)


 class OutlinesCoreJsonSchema:
-    params = [models, json_cases]
-    param_names = ["model", "json"]
+    params = [models, json_cases.keys()]
+    param_names = ["model", "json_schema_name"]
     timeout = 600

     def setup(self, model, _):
@@ -59,20 +62,22 @@ def setup(self, model, _):
         )
         self.tokenizer = TransformerTokenizer(self.tokenizer)

-    def time_outlines_core(self, _, json_case):
+    def time_outlines_core(self, _, json_schema_name):
         """Measure generation time with Outlines.

         Outlines' generation time is split between compiling an index for each
         regular expression, and walking this index while generating tokens.
         """
-        json_string, json_example = json_case["schema"], json_case["example"]
-        json_example_tokens = self.tokenizer.encode(json_example)[0][0]
+        json_string = json_cases[json_schema_name]["schema"]
+        json_samples = json_cases[json_schema_name]["samples"]

         regex_string = build_regex_from_schema(json.dumps(json_string))
         guide = RegexGuide(regex_string, self.tokenizer)

-        state = 0
-        for token in json_example_tokens:
-            _ = guide.get_next_instruction(state)
-            state = guide.get_next_state(state, token)
+        for json_sample in json_samples:
+            json_sample_tokens = self.tokenizer.encode(json_sample)[0][0]
+            state = guide.initial_state
+            for token in json_sample_tokens:
+                _ = guide.get_next_instruction(state)
+                state = guide.get_next_state(state, token)
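
Both JSON-schema suites lower the schema to a regular expression before building a guide. A minimal standalone sketch of that step, using build_regex_from_schema and RegexGuide as they appear in the benchmarks (the schema dict is a hypothetical example; the import paths are assumed for the outlines_core of this period, since the diff's collapsed header hides the actual imports):

import json

from outlines_core.fsm.guide import RegexGuide
from outlines_core.fsm.json_schema import build_regex_from_schema

schema = {  # hypothetical minimal schema
    "type": "object",
    "properties": {"id": {"type": "integer"}},
    "required": ["id"],
}

# build_regex_from_schema expects the schema serialized as a JSON string,
# hence json.dumps; the resulting regex then drives an ordinary RegexGuide.
regex_string = build_regex_from_schema(json.dumps(schema))
guide = RegexGuide(regex_string, tokenizer)  # tokenizer built as in setup()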
71 changes: 38 additions & 33 deletions src/data.py
@@ -1,43 +1,45 @@
+import json
+from pathlib import Path
+
+SAMPLES_PATH = Path(__file__).parent / "samples"
+
+
 models = [
     "NousResearch/Nous-Hermes-llama-2-7b",  # 32,000 tokens vocabulary
     "gpt2",  # 50,257 tokens vocabulary
     "NousResearch/Hermes-3-Llama-3.1-8B",  # 128,256 tokens vocabulary
     "unsloth/gemma-2-2b-it-bnb-4bit",  # 256,128 tokens vocabulary
 ]

-regex_cases = [
-    {
-        "name": "Phone Number",
-        "regex": r'\d{3}-\d{2}-\d{4}',
-        "example": '203-22-1234'
+
+regex_cases = {
+    "Phone Number": {
+        "regex": r"\d{3}-\d{3}-\d{4}",
+        "samples": json.load(open(SAMPLES_PATH / "phone_number.json")),
     },
-    {
-        "name": "URL",
-        "regex": r'(https?:\/\/)?([\da-z\.-]+)\.([a-z\.]{2,6})([\/\w \.-]*)*\/?',
-        "example": 'https://github.com/outlines-dev/outlines'
+    "URL": {
+        "regex": r"(https?:\/\/)?([\da-z\.-]+)\.([a-z\.]{2,6})([\/\w \.-]*)*\/?",
+        "samples": json.load(open(SAMPLES_PATH / "url.json")),
     },
-    {
-        "name": "GSM8K",
-        "regex": r'A: [\w \.\*\-=\+,\?/]{10,50}\. The answer is [1-9][0-9]{0,9}\.',
-        "example": 'A: Some thoughts before answering. The answer is 42.'
+    "GSM8K": {
+        "regex": r"A: [\w \.\*\-=\+,\?/]{10,50}\. The answer is [1-9][0-9]{0,9}\.",
+        "samples": json.load(open(SAMPLES_PATH / "gsm8k.json")),
+        # gsm8k.json attribution: https://huggingface.co/datasets/thesven/gsm8k-reasoning
     },
-    {
-        "name": "Complex string",
-        "regex": r'(0|[1-9][0-9]*)|true|false|([a-zA-Z_][a-zA-Z_0-9]*)',
-        "example": 'AVeryLongStringtoTest1234'
+    "Complex string": {
+        "regex": r"(0|[1-9][0-9]*)|true|false|([a-zA-Z_][a-zA-Z_0-9]*)",
+        "samples": json.load(open(SAMPLES_PATH / "complex_str.json")),
     },
-    {
-        "name": "Long integer",
-        "regex": r'\+[1-9]\d{1,14}',
-        "example": '1234567891234'
-    }
-]
+    "Long integer": {
+        "regex": r"\+[1-9]\d{1,14}",
+        "samples": json.load(open(SAMPLES_PATH / "long_integer.json")),
+    },
+}


-json_cases = [
-    {
-        "name": "RPG character",
-        "schema":
-        {
+json_cases = {
+    "RPG character": {
+        "schema": {
             "$defs": {
                 "Armor": {
                     "enum": ["leather", "chainmail", "plate"],
@@ -55,10 +57,11 @@
             "title": "Character",
             "type": "object",
         },
-        "example": """{'name': 'Super Warrior', 'age': 26, 'armor': 'leather', 'armor': 10}""",
+        "samples": [
+            json.dumps(sample)
+            for sample in json.load(open(SAMPLES_PATH / "rpg_characters.json"))
+        ],
     },
-    {
-        "name": "Simple nested schema",
+    "Simple nested schema": {
         "schema": {
             "$schema": "http://json-schema.org/draft-04/schema#",
             "title": "Schema for a recording",
@@ -91,6 +94,8 @@
             },
             "required": ["id", "work", "recording_artists"],
         },
-        "example": """{'id': 999, 'work': {'id': 1, 'name': 'Strasbourg Saint-Denis', 'composer': 'Roy Hargrove'}, 'recording_artists': [{'id': 2, 'name': 'Roy Hargrove', 'functions': ['Trumpet', 'Singing']}]}""",
+        "samples": [
+            json.dumps(sample)
+            for sample in json.load(open(SAMPLES_PATH / "recording_schema.json"))
+        ],
     },
-]
+}
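
The commit adds the sample files themselves but not the script that generated them. For the regex cases, one plausible way to produce 100 matching strings per pattern is random generation from the pattern followed by validation, sketched here with the third-party exrex package (entirely hypothetical: the use of exrex, the count, and the output path are assumptions, not part of the commit):

import json
import re

import exrex  # third-party generator of strings matching a regex

pattern = r"\+[1-9]\d{1,14}"  # the "Long integer" case above
samples = []
while len(samples) < 100:
    candidate = exrex.getone(pattern)
    if re.fullmatch(pattern, candidate):  # keep only exact matches
        samples.append(candidate)

with open("src/samples/long_integer.json", "w") as f:
    json.dump(samples, f)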
1 change: 1 addition & 0 deletions src/samples/complex_str.json
@@ -0,0 +1 @@
["falseJb", "false0IUnuntrueb_LgozC2VciR4TSU", "truefalsetruetrue", "60falsetrue0truefalsefalseucfalse", "GG6LOxmjtrue0", "ybZXifalsetrueIi3Ftrue", "falsetrueDlled4UiW0trueZJsNUjKfs", "4", "falseWkZpEfalse19falsefalsefalsetrue0", "true41falsectrue0falsetrue", "false11trues82", "true39falsetrue73true0falsetrue", "NDd9falsedjy3fGB", "true", "true58true0truefalse", "falset77_ZzU0sFE42etrue0D9", "52sYo7oF8YDPtrueu5x7eBuByBqZJNb00", "false22false8false", "falsefalsewkXfalseNl00true3", "FQJDLnTDVstku3J0X8d5RaerNJfaO96", "truefalsefalse0", "truehOfalse", "truefalseetCOz", "truefalsefalseOtruetrue0LHT0false", "0falsefalseWW3FVDzctruetruetrue", "false", "truetruefalsetruepZWNyfi12bkU", "truefalseigp1ub3", "truetrue64falsefalse", "SWAOdHtBfalse37UHNc1hlfAX_hEfalse7", "truetruefalseUMaFReibQfalse76_X8MWwTkRZfalse", "098truefalsetrue", "VNB0040rRrSOG048false", "xZRfalsetruetrue", "42", "falsetruetrueR", "Afalsetrue", "false0MUqmMTNtrueQLfwxtruetruetruefalse", "truearAcBkNR426yPWtruefalse37YgMuFwC2nfalse0", "uDvnVfalse2falsefalse24", "1iEwsRFaXzPj", "falsetruefalse40truefalsefalsetrueCVTyhXpeufalse", "iVdgt2_24trueO57", "un7HivkLu360falseXI9dfg0BEU53izLz11falsefalse", "Wfalse050false0", "grRmM2N7R0iQ0falsefalse0amOajE", "falsefalsefalsed70", "Fodkhdk1rXfalseV8fBRtrue034", "truek9gRy7Ll0qPx4gXTY_W", "truedI5xBI3cTi6", "falsetrueK541tVn1kofalsetrue", "QlWNmtruefalsevHmJX8i00falseap", "QuKVqKdmCfalsemx5RWQa", "R92truefalsetrue", "pRQW6krgFtrue0kqT96b2truelUarMp30v_68w66", "Nnfalse0", "false", "kVnGMeGfalse", "false", "533true096zyxAZXkVsV", "ZHGMEshg67", "27truefalsetrue", "0truetruetrueE1xfalse67", "false67DOVb2Ohcfalse", "false", "83truefalsefalsefalsefalse1843HAUZ", "falsetrueSgcSyFrMLtrue", "57truezn23BcwaTfalse", "BQsQY2W18false14", "false148trueIBjNKK7mWY", "036", "falsecZYAdPOjGkofalse2falsemN3ktruefalsetrue", "87", "true0falsefalse", "wtruetruetruetrue0true0", "truetruefalsetruefalsefalsetUxqN3BsGJ0", "trueryEUpMh0_UQjnA5AOhP6519", "falsefalsetrue0truetruefalsehLE2itrue57", "truefalsefalsefalse350", "zOyG0truefalsefalse", "falseYlQfg", "21Og", "0false00falsexA2mX7true", "96truetrue", "SFtg_HU_5GvFAkP0Pxw8K5ftruefalse", "ZtruGMWBuItruectW510falsetrue", "trueBvPmtrue00", "_fDdoTYwtrue", "295true97", "true", "false", "false0H0true000", "93ClaPkD41h7false76falsemtruePrRNiFsAOcY1YxC", "falsefalse170truerfmQImvL0", "false1", "dpvYu", "false0", "0true", "trueDpQs13MtrueEl9619true", "true46"]