Skip to content

Commit

Permalink
test fsm_union and walk_fsm
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 committed Oct 14, 2024
1 parent 8c16102 commit ab0cc4f
Showing 1 changed file with 101 additions and 0 deletions.
101 changes: 101 additions & 0 deletions tests/fsm/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,104 @@ def test_sequential_parse_example(cleanup_lark_import):

if i + 1 == len(input_tokens):
assert all(tk in next_vocab for tk in ["\n", "\nde", " ", " + 1"])


# TODO: Remove once fsm_union and walk_fsm are implemented in Outlines-Core
import interegular # noqa

from outlines.fsm.parsing import fsm_union, walk_fsm # noqa


def test_outlines_interegular_union_consistency():
fsm0 = interegular.parse_pattern(r"abc").to_fsm()
fsm1 = interegular.parse_pattern(r"WXYZ").to_fsm()
fsm2 = interegular.parse_pattern(r"12345").to_fsm()

interegular_unioned_fsm = fsm0 | fsm1 | fsm2
outlines_unioned_fsm, _ = fsm_union([fsm0, fsm1, fsm2])

assert list(outlines_unioned_fsm.strings()) == list(
interegular_unioned_fsm.strings()
)


def _reconstruct_fsms(fsm, fsms_to_trans_finals):
"""Reconstruct the original fsms for testing purposes"""
reconstructed_fsms = []
for transitions, finals, state_map in fsms_to_trans_finals.values():
inv_state_map = {new: orig for orig, news in state_map.items() for new in news}
states = set(inv_state_map.values())
initial = inv_state_map.get(fsm.initial) or next(
(orig for orig, news in state_map.items() if fsm.initial in news), None
)
finals = {inv_state_map[s] for s in finals}

transition_map = {}
alphabet = {}
for trans_id, (from_state, to_state) in enumerate(transitions):
orig_from, orig_to = inv_state_map[from_state], inv_state_map[to_state]
# Collect symbols associated with the transition
symbols = {
symbol
for trans, dest in fsm.map.get(from_state, {}).items()
if dest == to_state
for symbol in fsm.alphabet.by_transition.get(trans, [])
}
if symbols:
# NOTE: THIS RECONSTRUCTOR DOESNT WORK FOR MORE THAN ONE TRANSITION PER SYMBOL
assert len(symbols) == 1
symbol = list(symbols)[0]
alphabet[symbol] = trans_id
transition_map.setdefault(orig_from, {})[trans_id] = orig_to

reconstructed_fsms.append(
interegular.fsm.FSM(
alphabet=interegular.fsm.Alphabet(alphabet),
states=frozenset(states),
initial=initial,
finals=frozenset(finals),
map=transition_map,
__no_validation__=True,
)
)
return reconstructed_fsms


def test_fsm_to_trans_finals_reconstruction():
"""Assert that _fsms_to_trans_finals is correct by reconstructing original fsms"""
fsm0 = interegular.parse_pattern(r"abc").to_fsm()
fsm1 = interegular.parse_pattern(r"XYZ").to_fsm()
fsm2 = interegular.parse_pattern(r"12345").to_fsm()

fsm, _fsms_to_trans_finals = fsm_union([fsm0, fsm1, fsm2])

reconstructed = _reconstruct_fsms(fsm, _fsms_to_trans_finals)

# assert reconstruction equivalent
assert list(fsm0.strings()) == list(reconstructed[0].strings())
assert list(fsm1.strings()) == list(reconstructed[1].strings())
assert list(fsm2.strings()) == list(reconstructed[2].strings())


def test_walk_fsm():
fsm = interegular.parse_pattern(r"abc*d").to_fsm()
# convert to BetterFSM
fsm = fsm_union([fsm])[0]

# if match, produce equivalent number of states, assert state can terminate
transitions = [fsm.alphabet[letter] for letter in "abcccd"]
accepted_states = walk_fsm(fsm, transitions, fsm.initial, full_match=True)
assert len(accepted_states) == len(transitions)
assert accepted_states[-1] in fsm.finals

# if no match, assert empty
accepted_states = walk_fsm(
fsm, [fsm.alphabet[letter] for letter in "b"], fsm.initial, full_match=True
)
assert accepted_states == []

# if full_match, but last state not present, assert empty
accepted_states = walk_fsm(
fsm, [fsm.alphabet[letter] for letter in "abc"], fsm.initial, full_match=True
)
assert accepted_states == []

0 comments on commit ab0cc4f

Please sign in to comment.