Commit

Add Json Schema benchmark

rlouf committed Aug 15, 2024
1 parent 32c2bef commit daacc8a

Showing 2 changed files with 125 additions and 7 deletions.
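Both files define airspeed velocity (asv) style benchmark classes: params, param_names, a setup method, and time_-prefixed timing methods. Assuming an asv-configured checkout, which this commit does not show, the new benchmarks could be selected with asv run --bench JsonSchema, since asv's --bench flag filters benchmarks by regular expression.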
61 changes: 57 additions & 4 deletions src/lfe.py
@@ -1,5 +1,5 @@
"""Benchmark the lm-format-enforcer library."""
from lmformatenforcer import RegexParser, TokenEnforcer
from lmformatenforcer import JsonSchemaParser, RegexParser, TokenEnforcer
from lmformatenforcer.integrations.transformers import (
    build_token_enforcer_tokenizer_data,
)
@@ -12,7 +12,7 @@
"google/gemma-2-2b-it", # 256,128 tokens vocabulary
]

case = [
regex_case = [
(r"\d{3}-\d{2}-\d{4}", "203-22-1234"),
(
r"(https?:\/\/)?([\da-z\.-]+)\.([a-z\.]{2,6})([\/\w \.-]*)*\/?",
@@ -25,8 +25,8 @@
]


class LMFormatEnforcer:
    params = [models, case]
class LMFormatEnforcerRegex:
    params = [models, regex_case]
    param_names = ["model", "regex"]
    timeout = 600

@@ -51,3 +51,56 @@ def time_lfe(self, _, regex):

        for i in range(len(regex_example_tokens)):
            _ = token_enforcer.get_allowed_tokens(regex_example_tokens[: i + 1])


json_case = [
    (
        {
            "$defs": {
                "Armor": {
                    "enum": ["leather", "chainmail", "plate"],
                    "title": "Armor",
                    "type": "string",
                }
            },
            "properties": {
                "name": {"maxLength": 10, "title": "Name", "type": "string"},
                "age": {"title": "Age", "type": "integer"},
                "armor": {"$ref": "#/$defs/Armor"},
                "strength": {"title": "Strength", "type": "integer"},
            },
            "required": ["name", "age", "armor", "strength"],
            "title": "Character",
            "type": "object",
        },
        """{"name": "Warrior", "age": 26, "armor": "leather", "strength": 10}""",
    )
]


class LMFormatEnforcerJsonSchema:
    params = [models, json_case]
    param_names = ["model", "json"]
    timeout = 600

    def setup(self, model, _):
        """Set up the benchmark.
        We convert the tokenizer during set up as this only
        needs to be done once for a given model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model, clean_up_tokenization_spaces=True
        )
        self.tokenizer_data = build_token_enforcer_tokenizer_data(self.tokenizer)

    def time_lfe(self, _, json_case):
        json_schema, json_example = json_case
        json_example_tokens = self.tokenizer.encode(json_example)

        parser = JsonSchemaParser(json_schema)
        token_enforcer = TokenEnforcer(self.tokenizer_data, parser)

        for i in range(len(json_example_tokens)):
            _ = token_enforcer.get_allowed_tokens(json_example_tokens[: i + 1])
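The Character schema in json_case has the shape of a pydantic-generated JSON Schema. As a reference point only — the commit does not say how the schema was produced — a minimal sketch of a pydantic v2 model whose model_json_schema() output has the same $defs/properties/required layout:

from enum import Enum

from pydantic import BaseModel, Field


class Armor(str, Enum):
    # Values mirror the "enum" entries in the benchmark schema.
    leather = "leather"
    chainmail = "chainmail"
    plate = "plate"


class Character(BaseModel):
    name: str = Field(max_length=10)  # becomes "maxLength": 10
    age: int
    armor: Armor
    strength: int


print(Character.model_json_schema())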
71 changes: 68 additions & 3 deletions src/outlines.py
@@ -1,7 +1,10 @@
"""Benchmark the Outlines library."""
import json

from transformers import AutoTokenizer

from outlines.fsm.guide import RegexGuide
from outlines.fsm.json_schema import build_regex_from_schema
from outlines.models.transformers import TransformerTokenizer

models = [
@@ -11,7 +14,7 @@
"google/gemma-2-2b-it", # 256,128 tokens vocabulary
]

case = [
regex_case = [
(r"\d{3}-\d{2}-\d{4}", "203-22-1234"),
(
r"(https?:\/\/)?([\da-z\.-]+)\.([a-z\.]{2,6})([\/\w \.-]*)*\/?",
@@ -24,8 +27,8 @@
]


class Outlines:
    params = [models, case]
class OutlinesRegex:
    params = [models, regex_case]
    param_names = ["model", "regex"]
    timeout = 600

@@ -57,3 +60,65 @@ def time_outlines(self, _, regex):
        for token in regex_example_tokens:
            _ = guide.get_next_instruction(state)
            state = guide.get_next_state(state, token)


json_case = [
    (
        {
            "$defs": {
                "Armor": {
                    "enum": ["leather", "chainmail", "plate"],
                    "title": "Armor",
                    "type": "string",
                }
            },
            "properties": {
                "name": {"maxLength": 10, "title": "Name", "type": "string"},
                "age": {"title": "Age", "type": "integer"},
                "armor": {"$ref": "#/$defs/Armor"},
                "strength": {"title": "Strength", "type": "integer"},
            },
            "required": ["name", "age", "armor", "strength"],
            "title": "Character",
            "type": "object",
        },
        """{"name": "Warrior", "age": 26, "armor": "leather", "strength": 10}""",
    )
]


class OutlinesJsonSchema:
    params = [models, json_case]
    param_names = ["model", "json"]
    timeout = 600

    def setup(self, model, _):
        """Set up the benchmark.
        We JIT-compile Numba functions and convert the vocabulary
        during set up as this only needs to be done once.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model, clean_up_tokenization_spaces=True
        )
        self.tokenizer = TransformerTokenizer(self.tokenizer)
        RegexGuide("a", self.tokenizer)  # JIT-compile and convert the vocabulary

    def time_outlines(self, _, json_case):
        """Measure generation time with Outlines.
        Outlines' generation time is split between compiling an index for each
        regular expression and walking this index while generating tokens.
        """
        json_schema, json_example = json_case
        json_example_tokens = self.tokenizer.encode(json_example)[0][0]

        regex_string = build_regex_from_schema(json.dumps(json_schema))
        guide = RegexGuide(regex_string, self.tokenizer)

        state = 0
        for token in json_example_tokens:
            _ = guide.get_next_instruction(state)
            state = guide.get_next_state(state, token)
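For reference, the schema-to-regex step can be exercised on its own. A minimal sketch reusing the build_regex_from_schema helper imported above (the exact regex it prints is an implementation detail of Outlines):

import json

from outlines.fsm.json_schema import build_regex_from_schema

schema = {
    "properties": {"age": {"title": "Age", "type": "integer"}},
    "required": ["age"],
    "title": "Person",
    "type": "object",
}

# As in time_outlines above, the schema is passed as a JSON string; the
# returned regex is intended to match exactly the schema-valid documents.
regex_string = build_regex_from_schema(json.dumps(schema))
print(regex_string)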
