Time tree building for LFE
rlouf committed Oct 16, 2024
1 parent e7c5dc2 commit 77fdf8f
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions src/benchmark_lfe.py
@@ -23,23 +23,20 @@ def setup(self, model, _):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model, clean_up_tokenization_spaces=True
         )
-        self.tokenizer_data = build_token_enforcer_tokenizer_data(self.tokenizer)
 
     def time_lfe(self, _, regex_name):
         regex_string = regex_cases[regex_name]["regex"]
         regex_samples = regex_cases[regex_name]["samples"]
 
         parser = RegexParser(regex_string)
-        token_enforcer = TokenEnforcer(self.tokenizer_data, parser)
+        tokenizer_data = build_token_enforcer_tokenizer_data(self.tokenizer)
+        token_enforcer = TokenEnforcer(tokenizer_data, parser)
 
         for regex_sample in regex_samples:
             regex_sample_tokens = self.tokenizer.encode(regex_sample)
             for i in range(len(regex_sample_tokens)):
                 _ = token_enforcer.get_allowed_tokens(regex_sample_tokens[: i + 1])
 
-    def teardown(self, *args):
-        del self.tokenizer_data
-
 
 class LMFormatEnforcerJsonSchema:
     params = [models, json_cases.keys()]
@@ -56,19 +53,16 @@ def setup(self, model, _):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model, clean_up_tokenization_spaces=True
         )
-        self.tokenizer_data = build_token_enforcer_tokenizer_data(self.tokenizer)
 
     def time_lfe(self, _, json_schema_name):
         json_string = json_cases[json_schema_name]["schema"]
         json_samples = json_cases[json_schema_name]["samples"]
 
         parser = JsonSchemaParser(json_string)
-        token_enforcer = TokenEnforcer(self.tokenizer_data, parser)
+        tokenizer_data = build_token_enforcer_tokenizer_data(self.tokenizer)
+        token_enforcer = TokenEnforcer(tokenizer_data, parser)
 
         for json_sample in json_samples:
             json_sample_tokens = self.tokenizer.encode(json_sample)
             for i in range(len(json_sample_tokens)):
                 _ = token_enforcer.get_allowed_tokens(json_sample_tokens[: i + 1])
 
-    def teardown(self, *args):
-        del self.tokenizer_data
-

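These classes appear to follow the airspeed velocity (asv) benchmark convention, in which setup runs outside the timer and only the body of a time_* method is measured; moving build_token_enforcer_tokenizer_data into time_lfe therefore folds lm-format-enforcer's token-tree construction into the reported time, which is what the commit title refers to. Below is a minimal standalone sketch of the call sequence now being timed, using the same lm-format-enforcer and transformers APIs as src/benchmark_lfe.py; the model name and regex are illustrative placeholders, not values taken from the benchmark.

# Minimal sketch of the timed call sequence, assuming the same
# lm-format-enforcer / transformers APIs used in src/benchmark_lfe.py.
# The model name and regex below are illustrative placeholders.
from lmformatenforcer import RegexParser, TokenEnforcer
from lmformatenforcer.integrations.transformers import (
    build_token_enforcer_tokenizer_data,
)
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "gpt2", clean_up_tokenization_spaces=True
)

# After this commit, building the tokenizer data (the token tree) happens
# inside the timed function, so its cost is part of the measurement.
tokenizer_data = build_token_enforcer_tokenizer_data(tokenizer)
token_enforcer = TokenEnforcer(tokenizer_data, RegexParser(r"\d{3}-\d{4}"))

# Replay a sample one token at a time, querying the allowed next tokens
# for every prefix, exactly as the benchmark loop does.
tokens = tokenizer.encode("555-0199")
for i in range(len(tokens)):
    _ = token_enforcer.get_allowed_tokens(tokens[: i + 1])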