From 11f8cd39148dbda4021a195b15d8cf1b083d4e8b Mon Sep 17 00:00:00 2001 From: Kaustubh Date: Tue, 22 Oct 2024 16:59:04 +0530 Subject: [PATCH] Apply suggestions from code review Co-authored-by: Brandon T. Willard <971601+brandonwillard@users.noreply.github.com> Co-authored-by: Victoria Terenina --- python/outlines_core/fsm/outlines_core_rs.pyi | 2 +- src/python_bindings/mod.rs | 26 +++++++------------ 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/python/outlines_core/fsm/outlines_core_rs.pyi b/python/outlines_core/fsm/outlines_core_rs.pyi index 4f38469..3dfe41e 100644 --- a/python/outlines_core/fsm/outlines_core_rs.pyi +++ b/python/outlines_core/fsm/outlines_core_rs.pyi @@ -88,7 +88,7 @@ class Vocabulary: ... class PyVocabIndex: - def get_next_instruction(self, state: int): + def get_allowed_tokens(self, state: int): """ Return the next instruction for guided generation. """ diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 849e6c8..62ebd23 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -65,7 +65,7 @@ impl PyVocabIndex { ) -> PyResult { let mut states_to_token_subsets: HashMap> = HashMap::new(); let mut seen: HashSet = HashSet::new(); - let mut next_states: HashSet = HashSet::from_iter(vec![fsm_info.initial]); + let mut next_states: HashSet = HashSet::from([fsm_info.initial]); let vocabulary_transition_keys = get_vocabulary_transition_keys( &fsm_info.alphabet_symbol_mapping, @@ -99,18 +99,10 @@ impl PyVocabIndex { seen.insert(start_state); } - let mut is_valid = false; - for token_id_end_states in states_to_token_subsets.values() { - for end_state in token_id_end_states.values() { - if fsm_info.finals.contains(end_state) { - is_valid = true; - break; - } - } - if is_valid { - break; - } - } + let is_valid = states_to_token_subsets + .values() + .flat_map(|token_id_end_states| token_id_end_states.values()) + .any(|end_state| fsm_info.finals.contains(end_state)); if is_valid { Ok(Self { @@ -126,10 +118,10 @@ impl PyVocabIndex { } } - fn get_next_instruction(&mut self, state: u32) -> Vec { - let default = HashMap::new(); - let res = self.states_to_token_subsets.get(&state).unwrap_or(&default); - res.keys().cloned().collect() + fn get_allowed_tokens(&mut self, state: u32) -> Vec { + self.states_to_token_subsets + .get(&state) + .map_or_else(HashSet::new, |res| res.keys().cloned().collect()) } fn get_next_state(&mut self, state: u32, token_id: u32) -> i32 {