|
1 | 1 |
|
2 | 2 | from pathlib import Path |
| 3 | +from typing import Callable, Optional, Union |
3 | 4 |
|
4 | 5 | import pytest |
5 | 6 |
|
6 | | -from kaldi_active_grammar import Compiler, KaldiRule |
| 7 | +from kaldi_active_grammar import Compiler, KaldiRule, NativeWFST, WFST |
7 | 8 | from tests.helpers import * |
8 | 9 |
|
9 | 10 |
|
10 | 11 | class TestGrammar: |
11 | 12 |
|
12 | 13 | @pytest.fixture(autouse=True) |
13 | | - def setup(self, monkeypatch): |
| 14 | + def setup(self, monkeypatch, audio_generator): |
14 | 15 | monkeypatch.chdir(Path(__file__).parent) # Where model is |
15 | 16 | self.compiler = Compiler() |
16 | 17 | self.decoder = self.compiler.init_decoder() |
| 18 | + self.audio_generator = audio_generator |
17 | 19 |
|
18 | | - def test_simple_rule_creation_and_compilation(self, audio_generator): |
19 | | - """ Test basic rule creation and compilation """ |
20 | | - rule = KaldiRule(self.compiler, 'TestRule') |
21 | | - assert rule.name == 'TestRule' |
| 20 | + def make_rule(self, name: str, build_func: Callable[[Union[NativeWFST, WFST]], None]): |
| 21 | + rule = KaldiRule(self.compiler, name) |
| 22 | + assert rule.name == name |
22 | 23 | assert rule.fst is not None |
23 | | - |
24 | | - fst = rule.fst |
25 | | - initial_state = fst.add_state(initial=True) |
26 | | - final_state = fst.add_state(final=True) |
27 | | - fst.add_arc(initial_state, final_state, 'hello') |
28 | | - |
| 24 | + build_func(rule.fst) |
29 | 25 | rule.compile() |
30 | 26 | assert rule.compiled |
31 | 27 | rule.load() |
32 | 28 | assert rule.loaded |
| 29 | + return rule |
33 | 30 |
|
34 | | - self.decoder.decode(audio_generator("hello"), True, [True]) |
| 31 | + def decode(self, text: str, kaldi_rules_activity: list[bool], expected_rule: KaldiRule, expected_words_are_dictation_mask: Optional[list[bool]] = None): |
| 32 | + self.decoder.decode(self.audio_generator(text), True, kaldi_rules_activity) |
35 | 33 |
|
36 | 34 | output, info = self.decoder.get_output() |
37 | 35 | assert isinstance(output, str) |
38 | 36 | assert len(output) > 0 |
39 | 37 | assert_info_shape(info) |
40 | 38 |
|
41 | 39 | recognized_rule, words, words_are_dictation_mask = self.compiler.parse_output(output) |
42 | | - assert recognized_rule == rule |
43 | | - assert words == ['hello'] |
44 | | - assert words_are_dictation_mask == [False] |
| 40 | + assert recognized_rule == expected_rule |
| 41 | + assert words == text.split() |
| 42 | + if expected_words_are_dictation_mask is None: |
| 43 | + expected_words_are_dictation_mask = [False] * len(words) |
| 44 | + assert words_are_dictation_mask == expected_words_are_dictation_mask |
| 45 | + |
| 46 | + def test_simple_rule(self): |
| 47 | + def _build(fst): |
| 48 | + initial_state = fst.add_state(initial=True) |
| 49 | + final_state = fst.add_state(final=True) |
| 50 | + fst.add_arc(initial_state, final_state, 'hello') |
| 51 | + rule = self.make_rule('TestRule', _build) |
| 52 | + self.decode("hello", [True], rule) |
0 commit comments