From 11143dfa2d603a1f98adbcdedd546ee3722f3ea5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Mon, 11 Mar 2024 16:10:23 +0100 Subject: [PATCH] Add `Guide` interface --- docs/api/fsm.md | 1 - mkdocs.yml | 2 +- outlines/fsm/{fsm.py => guide.py} | 239 ++++++++++-------- outlines/generate/api.py | 5 +- outlines/generate/cfg.py | 4 +- outlines/generate/fsm.py | 4 +- outlines/generate/generator.py | 30 +-- outlines/generate/regex.py | 4 +- outlines/generate/text.py | 4 +- outlines/models/llamacpp.py | 16 +- outlines/serve/vllm.py | 4 +- tests/benchmark/conftest.py | 4 +- tests/benchmark/test_benchmark_json_schema.py | 4 +- tests/benchmark/test_benchmark_regex_fsm.py | 4 +- tests/fsm/{test_fsm.py => test_guide.py} | 188 +++++++++----- tests/generate/test_generator.py | 44 ++-- tests/test_grammars.py | 6 +- 17 files changed, 318 insertions(+), 245 deletions(-) delete mode 100644 docs/api/fsm.md rename outlines/fsm/{fsm.py => guide.py} (65%) rename tests/fsm/{test_fsm.py => test_guide.py} (52%) diff --git a/docs/api/fsm.md b/docs/api/fsm.md deleted file mode 100644 index 0f6a1ab0f..000000000 --- a/docs/api/fsm.md +++ /dev/null @@ -1 +0,0 @@ -::: outlines.fsm.fsm diff --git a/mkdocs.yml b/mkdocs.yml index 86ccf326d..76b0bc499 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -137,7 +137,7 @@ nav: - api/models.md - api/prompts.md - api/json_schema.md - - api/fsm.md + - api/guide.md - api/parsing.md - api/regex.md - api/samplers.md diff --git a/outlines/fsm/fsm.py b/outlines/fsm/guide.py similarity index 65% rename from outlines/fsm/fsm.py rename to outlines/fsm/guide.py index 1cf3f03b8..66b4388d0 100644 --- a/outlines/fsm/fsm.py +++ b/outlines/fsm/guide.py @@ -1,9 +1,9 @@ -from typing import TYPE_CHECKING, List, NewType, Protocol, Tuple +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Protocol, Tuple, Union import interegular from lark import Lark -# from outlines.fsm.parsing import PartialLark from outlines import grammars from outlines.caching import cache from outlines.fsm.regex import create_fsm_index_tokenizer, make_deterministic_fsm @@ -11,87 +11,96 @@ if TYPE_CHECKING: from outlines.models.tokenizer import Tokenizer -FSMState = NewType("FSMState", int) +@dataclass(frozen=True) +class Write: + """Write instruction. -class FSM(Protocol): - def is_final_state(self, state: FSMState) -> bool: - ... + Attributes + ---------- + tokens + The sequence of tokens to be added to the current sequence by the + generation process. - def allowed_token_ids(self, state: FSMState) -> List[int]: - ... + """ - def next_state(self, state: FSMState, token_id: int) -> FSMState: - ... + tokens: List[int] - def copy(self) -> "FSM": - ... +@dataclass(frozen=True) +class Generate: + """Generate instruction -class StopAtEosFSM: - """FSM to generate text until EOS has been generated.""" + Attributes + ---------- + tokens + The tokens that lead to a valid completion if generated. + """ - def __init__(self, tokenizer: "Tokenizer"): - self.eos_token_id = tokenizer.eos_token_id - self.vocabulary = tokenizer.vocabulary.values() - self.start_state: FSMState = FSMState(0) - self.final_state: FSMState = FSMState(1) + tokens: List[int] - def allowed_token_ids(self, state: FSMState) -> List[int]: - """Generate a list of allowed tokens for the next step. - When in the initial state we allow every token to be generated. - In the final state the only allowed token is `stop_token_id`. +Instruction = Union[Write, Generate] - Parameters - ---------- - state - The current state of the FSM. - Returns - ------- - A list that contains the tokens to mask. +class Guide(Protocol): + """Base definition of a generation guide. - """ - if self.is_final_state(state): - return [self.eos_token_id] - return list(self.vocabulary) + A generation guide defines the behavior of a finite-state machine that guides + a text generation procedure. Unlike the DFAs built from regular expressions + guides can also emit a `Write` instructions which tells the model that it can + append a sequence of tokens (or token word) instead of generating it. - def next_state(self, state: FSMState, token_id: int) -> FSMState: - """Update the state of the FSM. + """ - The FSM stays in the initial state `0` unless the specified stop token - has been generated or the maximum number of tokens has been reached. In - which case the FSM moves to the final state `-1`. + def get_next_instruction(self, state: int) -> Instruction: + ... - Parameters - ---------- - state - The current state of the FSM. - token_id - The id of the token that was just generated. + def get_next_state(self, state: int, token_id: int) -> int: + ... - Returns - ------- - The new state of the FSM. + def is_final_state(self, state: int) -> bool: + ... + + +class StopAtEOSGuide(Guide): + """Guide to generate tokens until the EOS token has been generated.""" + + final_state = 1 + start_state = 0 + + def __init__(self, tokenizer: "Tokenizer"): + """Initialize the generation guide. + + model + The logit generator used to generate the next token. """ + self.eos_token_id = tokenizer.eos_token_id + self.vocabulary = tokenizer.vocabulary.values() + + def get_next_instruction(self, state: int) -> Instruction: + if self.is_final_state(state): + return Write([self.eos_token_id]) + return Generate(list(self.vocabulary)) + + def get_next_state(self, state: int, token_id: int) -> int: if token_id == self.eos_token_id or state == self.final_state: return self.final_state return self.start_state - def is_final_state(self, state: FSMState) -> bool: - """Determine whether the current state of the FSM is a final state.""" + def is_final_state(self, state: int): return state == self.final_state - def copy(self) -> "StopAtEosFSM": - """Create a copy of the FSM.""" + def copy(self): return self -class RegexFSM: - """FSM to generate text that is in the language of a regular expression.""" +class RegexGuide(Guide): + """Guide to generate text in the language of a regular expression.""" + + initial_state = 0 def __init__(self, regex_string: str, tokenizer): @cache() @@ -129,70 +138,66 @@ def create_states_mapping( ) self.vocabulary = list(tokenizer.vocabulary.values()) self.eos_token_id = tokenizer.eos_token_id - self.start_state = FSMState(0) self.final_states = fsm_finals | {-1} - def allowed_token_ids(self, state: FSMState) -> List[int]: - """Generate a list of allowed tokens for the next step. + def get_next_instruction(self, state: int) -> Instruction: + """Return the next instruction for guided generation. - The initialization of the FSM builds an index which maps FSM states to a - map from authorized tokens to the state in which the FSM needs to move + The initialization of the guide builds an index which maps FSM states to a + map from authorized tokens to the state in which the guide needs to move if said token is generated. Therefore the authorized tokens at the current state are the keys of the map returned by the value of the index for current state. If the current state is not contained in the end this means that we are - in a final state of the FSM. We only authorize EOS tokens in the final + in a final state of the guide. We only authorize EOS tokens in the final state. Parameters ---------- state - The current state of the FSM. + The current state of the guide. Returns ------- - A list that contains the tokens to mask. + A `Generate` instance that contains the model and the allowed token ids. """ next_tokens_to_end_states = self.states_to_token_maps.get(state) - if next_tokens_to_end_states is None: - return [self.eos_token_id] - else: - return list(next_tokens_to_end_states.keys()) + return Write([self.eos_token_id]) + + return Generate(list(next_tokens_to_end_states.keys())) - def next_state(self, state: FSMState, token_id: int) -> FSMState: - """Update the state of the FSM. + def get_next_state(self, state: int, token_id: int) -> int: + """Update the state of the guide. - We use the index to determine to which state the FSM should transition + We use the index to determine to which state the guide should transition given the token that was just generated. Parameters ---------- state - The current state of the FSM. + The current state of the guide. token_id The id of the token that was just generated. Returns ------- - The new state of the FSM. + The new state of the guide. """ if token_id == self.eos_token_id: - return FSMState(-1) - elif ( - state in self.final_states - ): # Necessary because we keep generating EOS tokens when finished + return -1 + elif state in self.final_states: return state last_token_to_end_state = self.states_to_token_maps[state] next_state = last_token_to_end_state.get(token_id) if next_state is None: - return FSMState(-1) + next_state = -1 - return FSMState(next_state) + return next_state @classmethod def from_interegular_fsm( @@ -234,16 +239,16 @@ def create_states_mapping_from_interegular_fsm( from_interegular_instance.eos_token_id = tokenizer.eos_token_id return from_interegular_instance - def is_final_state(self, state: FSMState) -> bool: + def is_final_state(self, state: int) -> bool: + """Determine whether the current state of the guide is a final state.""" return state in self.final_states - def copy(self) -> "RegexFSM": - """Create a copy of the FSM.""" + def copy(self): return self -class CFGFSM(FSM): - """FSM to generate text that is in the language of a context-free grammar.""" +class CFGGuide(Guide): + """Guide to generate text that is in the language of a context-free grammar.""" def __init__(self, cfg_string: str, tokenizer): self.cfg_string = cfg_string @@ -267,33 +272,34 @@ def __init__(self, cfg_string: str, tokenizer): self.generation = "" self.reset_state = False self.allow_eos = False - self.regex_fsm: RegexFSM + self.regex_fsm: RegexGuide self.check_last = False self.proposal_last: List[int] = [] - self.regex_fsm_last: RegexFSM + self.regex_fsm_last: RegexGuide - self.start_state = FSMState(0) - self.final_state = FSMState(-1) + self.start_state = 0 + self.final_state = -1 - def allowed_token_ids(self, state: FSMState) -> List[int]: - """Generate a list of allowed tokens for the next step. + def get_next_instruction(self, state: int) -> Instruction: + """Generate an instruction for the next step. Upon initialization, the CFG incremental parser is used to determine the first regex and construct the first FSM to generate the first terminal. This FSM is used for proposals until either: - - The FSM is exhausted, and its only remaining option is the EOS - token, in which case we feed the generated terminal to the + - The FSM is exhausted, and its only remaining option is the EOS token, + in which case we feed the generated terminal to the CFG incremental parser and allow it to propose the next regex corresponding to the next set of valid terminals. - The current FSM can be exhausted, but the EOS token is not the only - remaining option. In this case we allow proposal of current terminal extensions, - store the current FSM and its state, then also use the CFG parser - to propose a new regex corresponding to terminating the current terminal - and starting the next one. The model can then sample from either of these sets - to determine whether to extend the current terminal or terminate it and start the next one. + remaining option. In this case we allow proposal of current terminal + extensions, store the current FSM and its state, then also use the CFG + parser to propose a new regex corresponding to terminating the current + terminal and starting the next one. The model can then sample from + either of these sets to determine whether to extend the current + terminal or terminate it and start the next one. The CFG incremental parser is allowed to propose the EOS token from any accepting state, and once it is generated, the FSM will continue to always generate the EOS token. @@ -309,17 +315,24 @@ def allowed_token_ids(self, state: FSMState) -> List[int]: """ if self.is_final_state(state): - return [self.tokenizer.eos_token_id] + return Write([self.tokenizer.eos_token_id]) - proposal = [] + proposal: List[int] = [] if self.generation != "": if self.check_last: proposer = self.regex_fsm_last else: proposer = self.regex_fsm - proposal += proposer.allowed_token_ids(state) + + instruction = proposer.get_next_instruction(state) + if isinstance(instruction, Write): + proposal += instruction.tokens + else: + proposal += instruction.tokens + if self.tokenizer.eos_token_id not in proposal: - return proposal + return Generate(proposal) + self.check_last = False proposal = [x for x in proposal if x != self.tokenizer.eos_token_id] if len(proposal) > 0: @@ -337,25 +350,31 @@ def allowed_token_ids(self, state: FSMState) -> List[int]: if self.terminal_regexps["$END"] in options: options.remove(self.terminal_regexps["$END"]) if len(options) == 0: - return [self.tokenizer.eos_token_id] + return Write([self.tokenizer.eos_token_id]) self.allow_eos = True options.add("") assert len(options) > 1 regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")" - self.regex_fsm = RegexFSM(regex_string, self.tokenizer) + self.regex_fsm = RegexGuide(regex_string, self.tokenizer) self.reset_state = True - proposal += self.regex_fsm.allowed_token_ids(self.start_state) + instruction = self.regex_fsm.get_next_instruction(self.start_state) + if isinstance(instruction, Write): + proposal += instruction.tokens + else: + proposal += instruction.tokens + if self.allow_eos: self.allow_eos = False else: proposal = [x for x in proposal if x != self.tokenizer.eos_token_id] assert len(proposal) > 0 - return proposal - def next_state(self, state: FSMState, token_id: int) -> FSMState: - """Update the state of the FSM. + return Generate(proposal) + + def get_next_state(self, state: int, token_id: int) -> int: + """Update the state of the guide. Transitions the underlying regex FSM to its next state. If at max tokens or EOS token, transition permanently to the final state. @@ -382,18 +401,18 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState: if self.check_last: if token_id in self.proposal_last: - return self.regex_fsm_last.next_state(state, token_id) + return self.regex_fsm_last.get_next_state(state, token_id) self.check_last = False if self.reset_state: self.reset_state = False state = self.start_state - return self.regex_fsm.next_state(state, token_id) + return self.regex_fsm.get_next_state(state, token_id) - def is_final_state(self, state: FSMState) -> bool: + def is_final_state(self, state: int) -> bool: return state == self.final_state - def copy(self) -> "CFGFSM": + def copy(self) -> "CFGGuide": """Create a copy of the FSM.""" - return CFGFSM(self.cfg_string, self.tokenizer) + return CFGGuide(self.cfg_string, self.tokenizer) diff --git a/outlines/generate/api.py b/outlines/generate/api.py index 8d3a88b94..97b9a981b 100644 --- a/outlines/generate/api.py +++ b/outlines/generate/api.py @@ -2,7 +2,6 @@ import torch -from outlines.fsm.fsm import FSMState from outlines.generate.generator import sequence_generator @@ -178,7 +177,7 @@ def __call__( prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0) attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0) - fsm_states = [FSMState(0) for _ in range(batch_size * num_samples)] + fsm_states = [0 for _ in range(batch_size * num_samples)] fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)] weights = torch.zeros( (batch_size * num_samples), dtype=torch.float, device=self.device @@ -291,7 +290,7 @@ def stream( prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0) attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0) - fsm_states = [FSMState(0) for _ in range(batch_size * num_samples)] + fsm_states = [0 for _ in range(batch_size * num_samples)] fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)] weights = torch.zeros( (batch_size * num_samples), diff --git a/outlines/generate/cfg.py b/outlines/generate/cfg.py index aaa90ca0c..aac7a9cd8 100644 --- a/outlines/generate/cfg.py +++ b/outlines/generate/cfg.py @@ -1,6 +1,6 @@ from functools import singledispatch -from outlines.fsm.fsm import CFGFSM +from outlines.fsm.guide import CFGGuide from outlines.generate.api import SequenceGenerator from outlines.models import OpenAI from outlines.models.llamacpp import ( @@ -29,7 +29,7 @@ def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenera A `SequenceGenerator` instance that generates text. """ - fsm = CFGFSM(cfg_str, model.tokenizer) + fsm = CFGGuide(cfg_str, model.tokenizer) device = model.device generator = SequenceGenerator(fsm, model, sampler, device) diff --git a/outlines/generate/fsm.py b/outlines/generate/fsm.py index 6f9b1b84b..80db350f0 100644 --- a/outlines/generate/fsm.py +++ b/outlines/generate/fsm.py @@ -1,6 +1,6 @@ import interegular -from outlines.fsm.fsm import RegexFSM +from outlines.fsm.guide import RegexGuide from outlines.generate.api import SequenceGenerator from outlines.samplers import Sampler, multinomial @@ -8,7 +8,7 @@ def fsm( model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial() ) -> SequenceGenerator: - fsm = RegexFSM.from_interegular_fsm(fsm, model.tokenizer) + fsm = RegexGuide.from_interegular_fsm(fsm, model.tokenizer) device = model.device generator = SequenceGenerator(fsm, model, sampler, device) return generator diff --git a/outlines/generate/generator.py b/outlines/generate/generator.py index 0749df517..ad8ae8537 100644 --- a/outlines/generate/generator.py +++ b/outlines/generate/generator.py @@ -4,10 +4,8 @@ import torch -from outlines.fsm.fsm import FSMState - if TYPE_CHECKING: - from outlines.fsm.fsm import FSM + from outlines.fsm.guide import Guide class ContextLengthExceededError(Exception): @@ -20,17 +18,17 @@ class GenerationState: kv_cache: torch.Tensor logits: torch.Tensor weights: torch.Tensor - fsm_states: List[FSMState] + fsm_states: List[int] def sequence_generator( model: Callable, sampler: Callable, - fsms: List["FSM"], + fsms: List["Guide"], token_ids: torch.Tensor, sequence_weights: torch.Tensor, attention_masks: torch.Tensor, - fsm_states: List[FSMState], + fsm_states: List[int], rng: torch.Generator = torch.Generator(), ) -> Iterator[GenerationState]: """Generates sequences of tokens. @@ -109,8 +107,8 @@ def sequence_generator( def get_next_fsm_states( - fsms: List["FSM"], fsm_states: List[FSMState], next_token_ids: torch.Tensor -) -> List[FSMState]: + fsms: List["Guide"], fsm_states: List[int], next_token_ids: torch.Tensor +) -> List[int]: """ Parameters @@ -126,12 +124,12 @@ def get_next_fsm_states( """ return [ - fsm.next_state(fsm_state, int(token_id[0])) + fsm.get_next_state(fsm_state, int(token_id[0])) for fsm, fsm_state, token_id in zip(fsms, fsm_states, next_token_ids) ] -def get_allowed_tokens(fsms: List["FSM"], fsm_states: List[FSMState]) -> torch.Tensor: +def get_allowed_tokens(fsms: List["Guide"], fsm_states: List[int]) -> torch.Tensor: """Get the new instructions for each sequence from the finite-state machine. Parameters @@ -146,10 +144,12 @@ def get_allowed_tokens(fsms: List["FSM"], fsm_states: List[FSMState]) -> torch.T A nested list that contains the ids of the logits to keep. """ - return [fsm.allowed_token_ids(state) for fsm, state in zip(fsms, fsm_states)] + return [ + fsm.get_next_instruction(state).tokens for fsm, state in zip(fsms, fsm_states) + ] -def is_generation_finished(fsms: List["FSM"], fsm_states: List[FSMState]) -> bool: +def is_generation_finished(fsms: List["Guide"], fsm_states: List[int]) -> bool: """Determine if the generation is finished. A generation is considered finished if the FSM of every sequence in the @@ -229,7 +229,7 @@ def update_attention_masks( ) -def reorder_fsms(fsms: List["FSM"], ancestors: torch.Tensor) -> List["FSM"]: +def reorder_fsms(fsms: List["Guide"], ancestors: torch.Tensor) -> List["Guide"]: reordered_fsms = [] for ancestor in ancestors: reordered_fsms.append(fsms[ancestor].copy()) @@ -237,9 +237,7 @@ def reorder_fsms(fsms: List["FSM"], ancestors: torch.Tensor) -> List["FSM"]: return reordered_fsms -def reorder_fsm_states( - fsm_states: List[FSMState], ancestors: torch.Tensor -) -> List[FSMState]: +def reorder_fsm_states(fsm_states: List[int], ancestors: torch.Tensor) -> List[int]: reordered_states = [] for ancestor in ancestors: reordered_states.append(fsm_states[ancestor]) diff --git a/outlines/generate/regex.py b/outlines/generate/regex.py index 5b4e2bde3..d53d9d3d7 100644 --- a/outlines/generate/regex.py +++ b/outlines/generate/regex.py @@ -1,6 +1,6 @@ from functools import singledispatch -from outlines.fsm.fsm import RegexFSM +from outlines.fsm.guide import RegexGuide from outlines.generate.api import SequenceGenerator from outlines.models import OpenAI from outlines.models.llamacpp import ( @@ -32,7 +32,7 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()): regular expression. """ - fsm = RegexFSM(regex_str, model.tokenizer) + fsm = RegexGuide(regex_str, model.tokenizer) device = model.device generator = SequenceGenerator(fsm, model, sampler, device) diff --git a/outlines/generate/text.py b/outlines/generate/text.py index 40d89bba0..ed389a306 100644 --- a/outlines/generate/text.py +++ b/outlines/generate/text.py @@ -1,6 +1,6 @@ from functools import singledispatch -from outlines.fsm.fsm import StopAtEosFSM +from outlines.fsm.guide import StopAtEOSGuide from outlines.generate import SequenceGenerator from outlines.models import LlamaCpp, OpenAI from outlines.models.llamacpp import LlamaSequenceGenerator @@ -30,7 +30,7 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator: A `SequenceGenerator` instance that generates text. """ - fsm = StopAtEosFSM(model.tokenizer) + fsm = StopAtEOSGuide(model.tokenizer) device = model.device generator = SequenceGenerator(fsm, model, sampler, device) diff --git a/outlines/models/llamacpp.py b/outlines/models/llamacpp.py index 09146af6c..bcffefd8e 100644 --- a/outlines/models/llamacpp.py +++ b/outlines/models/llamacpp.py @@ -5,7 +5,7 @@ import torch from numpy.typing import NDArray -from outlines.fsm.fsm import CFGFSM, FSM, FSMState, RegexFSM +from outlines.fsm.guide import CFGGuide, Guide, RegexGuide if TYPE_CHECKING: from llama_cpp import Llama @@ -126,7 +126,7 @@ def llamacpp( class LogitsProcessor: - def __init__(self, tokenizer: LlamaCppTokenizer, fsm: FSM): + def __init__(self, tokenizer: LlamaCppTokenizer, fsm: Guide): """A FSM-based logits processor. Parameters @@ -138,8 +138,8 @@ def __init__(self, tokenizer: LlamaCppTokenizer, fsm: FSM): """ self.tokenizer = tokenizer - self.fsm_state = FSMState(0) - self.fsm: FSM = fsm + self.fsm_state = 0 + self.fsm: Guide = fsm self.is_first_token = True def __call__( @@ -151,9 +151,9 @@ def __call__( self.is_first_token = False else: last_token = input_ids[-1] - self.fsm_state = self.fsm.next_state(self.fsm_state, last_token) + self.fsm_state = self.fsm.get_next_state(self.fsm_state, last_token) - allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state) + allowed_tokens = self.fsm.get_next_instruction(self.fsm_state).tokens mask = torch.full((scores.shape[-1],), -math.inf, device="cpu").numpy() mask[allowed_tokens] = 0 @@ -177,7 +177,7 @@ def __init__(self, regex_string: str, tokenizer: LlamaCppTokenizer): An instance of `Tokenizer` """ - fsm = RegexFSM(regex_string, tokenizer) + fsm = RegexGuide(regex_string, tokenizer) super().__init__(tokenizer, fsm) @@ -193,5 +193,5 @@ def __init__(self, cfg_str: str, tokenizer: LlamaCppTokenizer): An instance of `Tokenizer` """ - fsm = CFGFSM(cfg_str, tokenizer) + fsm = CFGGuide(cfg_str, tokenizer) super().__init__(tokenizer, fsm) diff --git a/outlines/serve/vllm.py b/outlines/serve/vllm.py index 847fe63cd..f95a04ad2 100644 --- a/outlines/serve/vllm.py +++ b/outlines/serve/vllm.py @@ -31,7 +31,7 @@ import torch from pydantic import BaseModel -from outlines.fsm.fsm import RegexFSM +from outlines.fsm.guide import RegexGuide from outlines.fsm.json_schema import build_regex_from_schema @@ -61,7 +61,7 @@ def __init__(self, regex_string, llm): ) tokenizer = self.adapt_tokenizer(tokenizer=tokenizer) - fsm = RegexFSM(regex_string, tokenizer) + fsm = RegexGuide(regex_string, tokenizer) self.fsm = fsm def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: diff --git a/tests/benchmark/conftest.py b/tests/benchmark/conftest.py index edf2ff614..902d5d6eb 100644 --- a/tests/benchmark/conftest.py +++ b/tests/benchmark/conftest.py @@ -1,7 +1,7 @@ import pytest from transformers import AutoTokenizer -from outlines.fsm.fsm import RegexFSM +from outlines.fsm.guide import RegexGuide from outlines.models.transformers import TransformerTokenizer @@ -13,5 +13,5 @@ def tokenizer(): @pytest.fixture def ensure_numba_compiled(tokenizer): - RegexFSM("a", tokenizer) + RegexGuide("a", tokenizer) return True diff --git a/tests/benchmark/test_benchmark_json_schema.py b/tests/benchmark/test_benchmark_json_schema.py index 3c3a4d3c3..33f3f5b16 100644 --- a/tests/benchmark/test_benchmark_json_schema.py +++ b/tests/benchmark/test_benchmark_json_schema.py @@ -4,7 +4,7 @@ outlines.disable_cache() -from outlines.fsm.fsm import RegexFSM # noqa: E402 +from outlines.fsm.guide import RegexGuide # noqa: E402 from outlines.fsm.json_schema import build_regex_from_schema # noqa: E402 simple_schema = """{ @@ -86,7 +86,7 @@ def test_benchmark_json_schema_to_fsm( schema = schemas[schema_name] regex = build_regex_from_schema(schema) benchmark.pedantic( - RegexFSM, + RegexGuide, args=(regex, tokenizer), rounds=8, ) diff --git a/tests/benchmark/test_benchmark_regex_fsm.py b/tests/benchmark/test_benchmark_regex_fsm.py index 18673d26e..e9e45052a 100644 --- a/tests/benchmark/test_benchmark_regex_fsm.py +++ b/tests/benchmark/test_benchmark_regex_fsm.py @@ -4,7 +4,7 @@ outlines.disable_cache() -from outlines.fsm.fsm import RegexFSM # noqa: E402 +from outlines.fsm.guide import RegexGuide # noqa: E402 regex_samples = { "email": r"[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?", @@ -26,7 +26,7 @@ def test_benchmark_regex_to_fsm( """Benchmark converting regex to FSM""" regex_str = regex_samples[regex_name] benchmark.pedantic( - RegexFSM, + RegexGuide, args=(regex_str, tokenizer), rounds=8, ) diff --git a/tests/fsm/test_fsm.py b/tests/fsm/test_guide.py similarity index 52% rename from tests/fsm/test_fsm.py rename to tests/fsm/test_guide.py index 289f3fd8d..4be5259d9 100644 --- a/tests/fsm/test_fsm.py +++ b/tests/fsm/test_guide.py @@ -1,6 +1,6 @@ import pytest -from outlines.fsm.fsm import CFGFSM, RegexFSM, StopAtEosFSM +from outlines.fsm.guide import CFGGuide, Generate, RegexGuide, StopAtEOSGuide, Write def test_stop_at_eos(): @@ -8,12 +8,18 @@ class MockTokenizer: vocabulary = {"a": 1, "eos": 2} eos_token_id = 2 - fsm = StopAtEosFSM(MockTokenizer()) + fsm = StopAtEOSGuide(MockTokenizer()) - assert fsm.allowed_token_ids(fsm.start_state) == [1, 2] - assert fsm.allowed_token_ids(fsm.final_state) == [2] - assert fsm.next_state(fsm.start_state, 2) == fsm.final_state - assert fsm.next_state(fsm.start_state, 1) == fsm.start_state + instruction = fsm.get_next_instruction(fsm.start_state) + assert isinstance(instruction, Generate) + assert instruction.tokens == [1, 2] + + instruction = fsm.get_next_instruction(fsm.final_state) + assert isinstance(instruction, Write) + assert instruction.tokens == [2] + + assert fsm.get_next_state(fsm.start_state, 2) == fsm.final_state + assert fsm.get_next_state(fsm.start_state, 1) == fsm.start_state assert fsm.is_final_state(fsm.start_state) is False assert fsm.is_final_state(fsm.final_state) is True @@ -30,7 +36,7 @@ def convert_token_to_string(self, token): regex_str = "[1-9]" with pytest.raises(ValueError, match="The vocabulary"): - RegexFSM(regex_str, MockTokenizer()) + RegexGuide(regex_str, MockTokenizer()) def test_regex(): @@ -44,12 +50,16 @@ def convert_token_to_string(self, token): regex_str = "[1-9]" tokenizer = MockTokenizer() - fsm = RegexFSM(regex_str, tokenizer) + fsm = RegexGuide(regex_str, tokenizer) assert fsm.states_to_token_maps == {0: {1: 1}} - assert fsm.allowed_token_ids(state=0) == [1] - assert fsm.next_state(state=0, token_id=1) == 1 - assert fsm.next_state(state=0, token_id=tokenizer.eos_token_id) == -1 + + instruction = fsm.get_next_instruction(0) + assert isinstance(instruction, Generate) + assert instruction.tokens == [1] + + assert fsm.get_next_state(state=0, token_id=1) == 1 + assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1 assert fsm.is_final_state(0) is False @@ -70,13 +80,13 @@ def convert_token_to_string(self, token): regex_str = r"`\n(\.\n)?`\n" tokenizer = MockTokenizer() - fsm = RegexFSM(regex_str, tokenizer) + fsm = RegexGuide(regex_str, tokenizer) - state = fsm.next_state(state=4, token_id=103) + state = fsm.get_next_state(state=4, token_id=103) assert state == 5 assert fsm.is_final_state(state) - state = fsm.next_state(state=5, token_id=103) + state = fsm.get_next_state(state=5, token_id=103) assert state == 5 assert fsm.is_final_state(-1) @@ -104,30 +114,40 @@ def decode(self, token_ids): expr: "{" expr "}" | "[" expr "]" | """ tokenizer = MockTokenizer() - fsm = CFGFSM(cfg_str, tokenizer) + fsm = CFGGuide(cfg_str, tokenizer) - assert set(fsm.allowed_token_ids(state=fsm.start_state)) == {1, 3, 5} - state = fsm.next_state(state=fsm.start_state, token_id=1) + instruction = fsm.get_next_instruction(fsm.start_state) + assert isinstance(instruction, Generate) + assert set(instruction.tokens) == {1, 3, 5} + state = fsm.get_next_state(state=fsm.start_state, token_id=1) assert fsm.generation == "{" assert not fsm.is_final_state(state) - assert set(fsm.allowed_token_ids(state=state)) == {1, 2, 3} - state = fsm.next_state(state=state, token_id=3) + instruction = fsm.get_next_instruction(state) + assert isinstance(instruction, Generate) + assert set(instruction.tokens) == {1, 2, 3} + state = fsm.get_next_state(state=state, token_id=3) assert fsm.generation == "{[" assert not fsm.is_final_state(state) - assert set(fsm.allowed_token_ids(state=state)) == {1, 3, 4} - state = fsm.next_state(state=state, token_id=4) + instruction = fsm.get_next_instruction(state) + assert isinstance(instruction, Generate) + assert set(instruction.tokens) == {1, 3, 4} + state = fsm.get_next_state(state=state, token_id=4) assert fsm.generation == "{[]" assert not fsm.is_final_state(state) - assert set(fsm.allowed_token_ids(state=state)) == {2} - state = fsm.next_state(state=state, token_id=2) + instruction = fsm.get_next_instruction(state) + assert isinstance(instruction, Generate) + assert set(instruction.tokens) == {2} + state = fsm.get_next_state(state=state, token_id=2) assert fsm.generation == "{[]}" assert not fsm.is_final_state(state) - assert set(fsm.allowed_token_ids(state=state)) == {5} - state = fsm.next_state(state=state, token_id=5) + instruction = fsm.get_next_instruction(state) + assert isinstance(instruction, Write) + assert set(instruction.tokens) == {5} + state = fsm.get_next_state(state=state, token_id=5) assert fsm.generation == "{[]}" assert fsm.is_final_state(state) @@ -155,26 +175,34 @@ def decode(self, token_ids): subexpr: expr | """ tokenizer = MockTokenizer() - fsm = CFGFSM(cfg_str, tokenizer) + fsm = CFGGuide(cfg_str, tokenizer) - assert set(fsm.allowed_token_ids(state=fsm.start_state)) == {1} - state = fsm.next_state(state=fsm.start_state, token_id=1) + instruction = fsm.get_next_instruction(fsm.start_state) + assert isinstance(instruction, Generate) + assert set(instruction.tokens) == {1} + state = fsm.get_next_state(state=fsm.start_state, token_id=1) assert fsm.generation == "(" assert not fsm.is_final_state(state) - assert set(fsm.allowed_token_ids(state=state)) == {1, 2} - state = fsm.next_state(state=state, token_id=2) + instruction = fsm.get_next_instruction(state) + assert isinstance(instruction, Generate) + assert set(instruction.tokens) == {1, 2} + state = fsm.get_next_state(state=state, token_id=2) assert fsm.generation == "()" assert not fsm.is_final_state(state) # possible to continue or terminate - assert set(fsm.allowed_token_ids(state=state)) == {1, 3} - state = fsm.next_state(state=state, token_id=3) # feed eos + instruction = fsm.get_next_instruction(state) + assert isinstance(instruction, Generate) + assert set(instruction.tokens) == {1, 3} + state = fsm.get_next_state(state=state, token_id=3) # feed eos assert fsm.generation == "()" assert fsm.is_final_state(state) # once eos generated, can only terminate - assert set(fsm.allowed_token_ids(state=state)) == {3} + instruction = fsm.get_next_instruction(state) + assert isinstance(instruction, Write) + assert set(instruction.tokens) == {3} def test_cfg_ignore_directive(): @@ -201,42 +229,56 @@ def decode(self, token_ids): %ignore WS """ tokenizer = MockTokenizer() - fsm = CFGFSM(cfg_str, tokenizer) + fsm = CFGGuide(cfg_str, tokenizer) state = 0 - assert set(fsm.allowed_token_ids(state=0)) == {1, 2} - state = fsm.next_state(state=0, token_id=2) + instruction = fsm.get_next_instruction(0) + assert isinstance(instruction, Generate) + assert set(instruction.tokens) == {1, 2} + state = fsm.get_next_state(state=0, token_id=2) assert fsm.generation == " " assert not fsm.is_final_state(state) - assert set(fsm.allowed_token_ids(state=0)) == {1, 2} - state = fsm.next_state(state=0, token_id=1) + instruction = fsm.get_next_instruction(0) + assert isinstance(instruction, Generate) + assert set(instruction.tokens) == {1, 2} + state = fsm.get_next_state(state=0, token_id=1) assert fsm.generation == " a" assert not fsm.is_final_state(state) - assert set(fsm.allowed_token_ids(state=state)) == {1, 2, 3} - state = fsm.next_state(state=state, token_id=2) + instruction = fsm.get_next_instruction(state) + assert isinstance(instruction, Generate) + assert set(instruction.tokens) == {1, 2, 3} + state = fsm.get_next_state(state=state, token_id=2) assert fsm.generation == " a " assert not fsm.is_final_state(state) - assert set(fsm.allowed_token_ids(state=state)) == {1, 2, 3} - state = fsm.next_state(state=state, token_id=2) + instruction = fsm.get_next_instruction(state) + assert isinstance(instruction, Generate) + assert set(instruction.tokens) == {1, 2, 3} + state = fsm.get_next_state(state=state, token_id=2) assert fsm.generation == " a " assert not fsm.is_final_state(state) - assert set(fsm.allowed_token_ids(state=state)) == {1, 2, 3} - state = fsm.next_state(state=state, token_id=1) + instruction = fsm.get_next_instruction(state) + assert isinstance(instruction, Generate) + assert set(instruction.tokens) == {1, 2, 3} + state = fsm.get_next_state(state=state, token_id=1) assert fsm.generation == " a a" assert not fsm.is_final_state(state) - assert set(fsm.allowed_token_ids(state=state)) == {1, 2, 3} - state = fsm.next_state(state=state, token_id=3) + instruction = fsm.get_next_instruction(state) + assert isinstance(instruction, Generate) + assert set(instruction.tokens) == {1, 2, 3} + state = fsm.get_next_state(state=state, token_id=3) assert fsm.generation == " a a" assert fsm.is_final_state(state) # once eos generated, can only terminate - assert set(fsm.allowed_token_ids(state=state)) == {3} + instruction = fsm.get_next_instruction(state) + assert isinstance(instruction, Write) + assert set(instruction.tokens) == {3} def test_cfg_multitoken_terminal(): @@ -261,23 +303,29 @@ def decode(self, token_ids): S: "aa" | "bb" """ tokenizer = MockTokenizer() - fsm = CFGFSM(cfg_str, tokenizer) + fsm = CFGGuide(cfg_str, tokenizer) - assert set(fsm.allowed_token_ids(state=fsm.start_state)) == {1, 2} + instruction = fsm.get_next_instruction(fsm.start_state) + assert isinstance(instruction, Generate) + assert set(instruction.tokens) == {1, 2} assert fsm.reset_state # starting new regex - state = fsm.next_state(state=fsm.start_state, token_id=1) + state = fsm.get_next_state(state=fsm.start_state, token_id=1) assert fsm.generation == "a" assert not fsm.is_final_state(state) - assert set(fsm.allowed_token_ids(state=state)) == {1} + instruction = fsm.get_next_instruction(state) + assert isinstance(instruction, Generate) + assert set(instruction.tokens) == {1} assert not fsm.reset_state # continuing current regex - state = fsm.next_state(state=state, token_id=1) + state = fsm.get_next_state(state=state, token_id=1) assert fsm.generation == "aa" assert not fsm.is_final_state(state) - assert set(fsm.allowed_token_ids(state=state)) == {3} + instruction = fsm.get_next_instruction(state) + assert isinstance(instruction, Write) + assert set(instruction.tokens) == {3} assert not fsm.reset_state # completing current regex - state = fsm.next_state(state=state, token_id=3) + state = fsm.get_next_state(state=state, token_id=3) assert fsm.generation == "aa" assert fsm.is_final_state(state) @@ -304,29 +352,39 @@ def decode(self, token_ids): s: "(" s ")" | /a+/ """ tokenizer = MockTokenizer() - fsm = CFGFSM(cfg_str, tokenizer) + fsm = CFGGuide(cfg_str, tokenizer) - assert set(fsm.allowed_token_ids(state=fsm.start_state)) == {1, 3} - state = fsm.next_state(state=fsm.start_state, token_id=1) + instruction = fsm.get_next_instruction(fsm.start_state) + assert isinstance(instruction, Generate) + assert set(instruction.tokens) == {1, 3} + state = fsm.get_next_state(state=fsm.start_state, token_id=1) assert fsm.generation == "(" assert not fsm.is_final_state(state) - assert set(fsm.allowed_token_ids(state=state)) == {1, 3} - state = fsm.next_state(state=state, token_id=3) + instruction = fsm.get_next_instruction(state) + assert isinstance(instruction, Generate) + assert set(instruction.tokens) == {1, 3} + state = fsm.get_next_state(state=state, token_id=3) assert fsm.generation == "(a" assert not fsm.is_final_state(state) - assert set(fsm.allowed_token_ids(state=state)) == {2, 3} - state = fsm.next_state(state=state, token_id=3) + instruction = fsm.get_next_instruction(state) + assert isinstance(instruction, Generate) + assert set(instruction.tokens) == {2, 3} + state = fsm.get_next_state(state=state, token_id=3) assert fsm.generation == "(aa" assert not fsm.is_final_state(state) - assert set(fsm.allowed_token_ids(state=state)) == {2, 3} - state = fsm.next_state(state=state, token_id=2) + instruction = fsm.get_next_instruction(state) + assert isinstance(instruction, Generate) + assert set(instruction.tokens) == {2, 3} + state = fsm.get_next_state(state=state, token_id=2) assert fsm.generation == "(aa)" assert not fsm.is_final_state(state) - assert set(fsm.allowed_token_ids(state=state)) == {4} - state = fsm.next_state(state=state, token_id=4) + instruction = fsm.get_next_instruction(state) + assert isinstance(instruction, Write) + assert set(instruction.tokens) == {4} + state = fsm.get_next_state(state=state, token_id=4) assert fsm.generation == "(aa)" assert fsm.is_final_state(state) diff --git a/tests/generate/test_generator.py b/tests/generate/test_generator.py index d3a3763e4..5a2edf8dc 100644 --- a/tests/generate/test_generator.py +++ b/tests/generate/test_generator.py @@ -4,7 +4,7 @@ import pytest import torch -from outlines.fsm.fsm import FSMState +from outlines.fsm.guide import Generate from outlines.generate.api import SequenceGenerator from outlines.generate.generator import ( bias_logits, @@ -21,11 +21,11 @@ def test_sequence_generator_class(): class MockFSM: first_state = 0 - def next_state(self, state, next_token_ids): + def get_next_state(self, state, next_token_ids): return 4 - def allowed_token_ids(self, *_): - return [4] + def get_next_instruction(self, *_): + return Generate([4]) def is_final_state(self, _): return True @@ -77,11 +77,11 @@ def __call__(self, biased_logits, *_): def test_sequence_generator_1d_single_iteration(): class MockFSM: - def next_state(self, state, next_token_ids): + def get_next_state(self, state, next_token_ids): return 0 - def allowed_token_ids(self, _): - return [0, 1, 2, 3] + def get_next_instruction(self, _): + return Generate([0, 1, 2, 3]) def is_final_state(self, _): return True @@ -132,11 +132,11 @@ def sampler(biased_logits, *_): def test_sequence_generator_1d_several_iterations(): class MockFSM: - def next_state(self, state, next_token_ids): - return FSMState(state + 1) + def get_next_state(self, state, next_token_ids): + return state + 1 - def allowed_token_ids(self, _): - return [0, 1, 2, 3] + def get_next_instruction(self, _): + return Generate([0, 1, 2, 3]) def is_final_state(self, state): if state < 2: @@ -194,11 +194,11 @@ def sampler(biased_logits, *_): def test_sequence_generator_2d_single_iteration(): class MockFSM: - def next_state(self, state, next_token_ids): + def get_next_state(self, state, next_token_ids): return 0 - def allowed_token_ids(self, _): - return [0, 1, 2, 3] + def get_next_instruction(self, _): + return Generate([0, 1, 2, 3]) def is_final_state(self, _): return True @@ -260,11 +260,11 @@ def sampler(biased_logits, *_): def test_sequence_generator_2d_several_iterations(): class MockFSM: - def next_state(self, state, next_token_ids): - return FSMState(state + 1) + def get_next_state(self, state, next_token_ids): + return state + 1 - def allowed_token_ids(self, _): - return [0, 1, 2, 3] + def get_next_instruction(self, _): + return Generate([0, 1, 2, 3]) def is_final_state(self, state): if state < 2: @@ -337,7 +337,7 @@ def sampler(biased_logits, *_): def test_get_next_fsm_states(): class MockFSM: - def next_state(self, state, next_token_ids): + def get_next_state(self, state, next_token_ids): return 0 def copy(self): @@ -352,10 +352,10 @@ def copy(self): assert result == [0, 0] -def test_get_allowed_token_idss(): +def test_get_get_next_instructions(): class MockFSM: - def allowed_token_ids(self, _): - return [1, 2, 3, 4] + def get_next_instruction(self, _): + return Generate([1, 2, 3, 4]) result = get_allowed_tokens([MockFSM()], [0]) assert result == [[1, 2, 3, 4]] diff --git a/tests/test_grammars.py b/tests/test_grammars.py index bee5c5378..3d704f094 100644 --- a/tests/test_grammars.py +++ b/tests/test_grammars.py @@ -1,7 +1,7 @@ import pytest import outlines.grammars as grammars -from outlines.fsm.fsm import CFGFSM +from outlines.fsm.guide import CFGGuide @pytest.mark.parametrize("grammar", [grammars.json, grammars.arithmetic]) @@ -27,5 +27,5 @@ def decode(self, token_ids): s: "(" s ")" | /a+/ """ tokenizer = MockTokenizer() - fsm = CFGFSM(cfg_str, tokenizer) - assert isinstance(fsm, CFGFSM) + fsm = CFGGuide(cfg_str, tokenizer) + assert isinstance(fsm, CFGGuide)