Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Brandon T. Willard <[email protected]>
Co-authored-by: Victoria Terenina <[email protected]>
  • Loading branch information
3 people authored Oct 22, 2024
1 parent c8e03cc commit 11f8cd3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 18 deletions.
2 changes: 1 addition & 1 deletion python/outlines_core/fsm/outlines_core_rs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class Vocabulary:
...

class PyVocabIndex:
def get_next_instruction(self, state: int):
def get_allowed_tokens(self, state: int):
"""
Return the next instruction for guided generation.
"""
Expand Down
26 changes: 9 additions & 17 deletions src/python_bindings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl PyVocabIndex {
) -> PyResult<Self> {
let mut states_to_token_subsets: HashMap<u32, HashMap<u32, u32>> = HashMap::new();
let mut seen: HashSet<State> = HashSet::new();
let mut next_states: HashSet<State> = HashSet::from_iter(vec![fsm_info.initial]);
let mut next_states: HashSet<State> = HashSet::from([fsm_info.initial]);

let vocabulary_transition_keys = get_vocabulary_transition_keys(
&fsm_info.alphabet_symbol_mapping,
Expand Down Expand Up @@ -99,18 +99,10 @@ impl PyVocabIndex {
seen.insert(start_state);
}

let mut is_valid = false;
for token_id_end_states in states_to_token_subsets.values() {
for end_state in token_id_end_states.values() {
if fsm_info.finals.contains(end_state) {
is_valid = true;
break;
}
}
if is_valid {
break;
}
}
let is_valid = states_to_token_subsets
.values()
.flat_map(|token_id_end_states| token_id_end_states.values())
.any(|end_state| fsm_info.finals.contains(end_state));

if is_valid {
Ok(Self {
Expand All @@ -126,10 +118,10 @@ impl PyVocabIndex {
}
}

fn get_next_instruction(&mut self, state: u32) -> Vec<u32> {
let default = HashMap::new();
let res = self.states_to_token_subsets.get(&state).unwrap_or(&default);
res.keys().cloned().collect()
fn get_allowed_tokens(&mut self, state: u32) -> Vec<u32> {
self.states_to_token_subsets
.get(&state)
.map_or_else(HashSet::new, |res| res.keys().cloned().collect())
}

fn get_next_state(&mut self, state: u32, token_id: u32) -> i32 {
Expand Down

0 comments on commit 11f8cd3

Please sign in to comment.