Skip to content

Commit

Permalink
Add tokenizer data build time to lm-format-enforcer timings
Browse files Browse the repository at this point in the history
Building the tokenizer data needs to be run every time one starts a new
process to perform structured generation, so it belongs in the measured
time. This is equivalent to `outlines`'s compilation step.
  • Loading branch information
rlouf committed Oct 21, 2024
1 parent c86e55d commit f208928
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions src/benchmark_lfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ def do_setup(self, model, samples):
self.tokenizer = AutoTokenizer.from_pretrained(
model, clean_up_tokenization_spaces=True
)
self.tokenizer_data = build_token_enforcer_tokenizer_data(self.tokenizer)
self.all_tokenized_samples = [
self.tokenizer.encode(sample) for sample in samples
]
Expand All @@ -27,9 +26,6 @@ def _get_first_token(self, token_enforcer):
"""Get first token to verify lazy index is fully warmed up"""
_ = token_enforcer.get_allowed_tokens(self.all_tokenized_samples[0][:1])

def teardown(self, *args):
del self.tokenizer_data


class LMFormatEnforcerRegex(LMFormatEnforcerBenchmark):
params = [models, regex_cases.keys()]
Expand All @@ -43,7 +39,8 @@ def setup(self, model, regex_name):
def _get_enforcer(self, regex_name):
pattern = regex_cases[regex_name]["regex"]
parser = RegexParser(pattern)
return TokenEnforcer(self.tokenizer_data, parser)
tokenizer_data = build_token_enforcer_tokenizer_data(self.tokenizer)
return TokenEnforcer(tokenizer_data, parser)

def time_lfe_total(self, _, regex_name):
enforcer = self._get_enforcer(regex_name)
Expand Down Expand Up @@ -87,7 +84,8 @@ def setup(self, model, json_schema_name):
def _get_enforcer(self, json_schema_name):
schema = json_cases[json_schema_name]["schema"]
parser = JsonSchemaParser(schema)
return TokenEnforcer(self.tokenizer_data, parser)
tokenizer_data = build_token_enforcer_tokenizer_data(self.tokenizer)
return TokenEnforcer(tokenizer_data, parser)

def time_lfe_total(self, _, json_schema_name):
enforcer = self._get_enforcer(json_schema_name)
Expand Down

0 comments on commit f208928

Please sign in to comment.