Skip to content

Commit 109e294

Browse files
committed
Refactor tests
1 parent 1f1fdb0 commit 109e294

File tree

1 file changed

+24
-16
lines changed

1 file changed

+24
-16
lines changed

tests/test_grammar.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,52 @@
11

22
from pathlib import Path
3+
from typing import Callable, Optional, Union
34

45
import pytest
56

6-
from kaldi_active_grammar import Compiler, KaldiRule
7+
from kaldi_active_grammar import Compiler, KaldiRule, NativeWFST, WFST
78
from tests.helpers import *
89

910

1011
class TestGrammar:
1112

1213
@pytest.fixture(autouse=True)
13-
def setup(self, monkeypatch):
14+
def setup(self, monkeypatch, audio_generator):
1415
monkeypatch.chdir(Path(__file__).parent) # Where model is
1516
self.compiler = Compiler()
1617
self.decoder = self.compiler.init_decoder()
18+
self.audio_generator = audio_generator
1719

18-
def test_simple_rule_creation_and_compilation(self, audio_generator):
19-
""" Test basic rule creation and compilation """
20-
rule = KaldiRule(self.compiler, 'TestRule')
21-
assert rule.name == 'TestRule'
20+
def make_rule(self, name: str, build_func: Callable[[Union[NativeWFST, WFST]], None]):
21+
rule = KaldiRule(self.compiler, name)
22+
assert rule.name == name
2223
assert rule.fst is not None
23-
24-
fst = rule.fst
25-
initial_state = fst.add_state(initial=True)
26-
final_state = fst.add_state(final=True)
27-
fst.add_arc(initial_state, final_state, 'hello')
28-
24+
build_func(rule.fst)
2925
rule.compile()
3026
assert rule.compiled
3127
rule.load()
3228
assert rule.loaded
29+
return rule
3330

34-
self.decoder.decode(audio_generator("hello"), True, [True])
31+
def decode(self, text: str, kaldi_rules_activity: list[bool], expected_rule: KaldiRule, expected_words_are_dictation_mask: Optional[list[bool]] = None):
32+
self.decoder.decode(self.audio_generator(text), True, kaldi_rules_activity)
3533

3634
output, info = self.decoder.get_output()
3735
assert isinstance(output, str)
3836
assert len(output) > 0
3937
assert_info_shape(info)
4038

4139
recognized_rule, words, words_are_dictation_mask = self.compiler.parse_output(output)
42-
assert recognized_rule == rule
43-
assert words == ['hello']
44-
assert words_are_dictation_mask == [False]
40+
assert recognized_rule == expected_rule
41+
assert words == text.split()
42+
if expected_words_are_dictation_mask is None:
43+
expected_words_are_dictation_mask = [False] * len(words)
44+
assert words_are_dictation_mask == expected_words_are_dictation_mask
45+
46+
def test_simple_rule(self):
47+
def _build(fst):
48+
initial_state = fst.add_state(initial=True)
49+
final_state = fst.add_state(final=True)
50+
fst.add_arc(initial_state, final_state, 'hello')
51+
rule = self.make_rule('TestRule', _build)
52+
self.decode("hello", [True], rule)

0 commit comments

Comments
 (0)