Skip to content

Commit

Permalink
feat: Add .remove(key) method, closes #1
Browse files Browse the repository at this point in the history
  • Loading branch information
vemonet committed Aug 22, 2024
1 parent e51ddba commit c4707ae
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 2 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ cargo doc --all --all-features
### ⏱️ Benchmark

Running benchmarks requires to enable rust nightly: `rustup default nightly`
Running benchmarks requires to enable rust nightly: `rustup override set nightly`

```bash
cargo bench
Expand Down
116 changes: 116 additions & 0 deletions src/trie.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,122 @@ impl<K: Eq + Ord + Clone, V: Clone> Trie<K, V> {
self.root.insert(key, value);
}

/// Removes a key from the trie, if it exists.
///
/// # Example
///
/// ```rust
/// use ptrie::Trie;
///
/// let mut t = Trie::new();
/// let data = "test".bytes();
/// t.insert(data.clone(), 42);
/// assert!(t.contains_key(data.clone()));
///
/// t.remove(data.clone());
/// assert!(!t.contains_key(data));
/// t.remove("toto".bytes());
/// ```
pub fn remove<I: Iterator<Item = K>>(&mut self, key: I) -> Option<V> {
Self::remove_recursive(&mut self.root, key)
}

fn remove_recursive<I: Iterator<Item = K>>(node: &mut TrieNode<K, V>, mut key: I) -> Option<V> {
if let Some(k) = key.next() {
if let Some(index) = node.children.iter().position(|(key_part, _)| key_part == &k) {
let child = &mut node.children[index];
let result = Self::remove_recursive(&mut child.1, key);

// If the child node is now empty, remove it
if child.1.value.is_none() && child.1.children.is_empty() {
node.children.remove(index);
}

return result;
} else {
// Key part not found
return None;
}
}

// Reached the node corresponding to the full key
node.value.take()
}

// pub fn remove<I: Iterator<Item = K>>(&mut self, key: I) -> Option<V> {
// let root = &mut self.root;
// self.remove_recursive(root, key)
// }

// fn remove_recursive<I: Iterator<Item = K>>(
// &mut self,
// node: &mut TrieNode<K, V>,
// mut key: I,
// ) -> Option<V> {
// if let Some(k) = key.next() {
// // If the next part of the key exists in the children, recurse deeper
// if let Some((_, child)) = node.children.iter_mut().find(|(key_part, _)| key_part == &k) {
// let result = self.remove_recursive(child, key);

// // If the child is now empty (no value and no children), remove it
// if child.value.is_none() && child.children.is_empty() {
// node.children.retain(|(key_part, _)| key_part != &k);
// }
// return result;
// } else {
// // If the key is not found, return None
// return None;
// }
// }

// // We've reached the node corresponding to the full key
// node.value.take()
// }

// /// Removes a key from the trie
// ///
// /// # Example
// ///
// /// ```rust
// /// use ptrie::Trie;
// ///
// /// let mut t = Trie::new();
// /// let data = "test".bytes();
// /// t.insert(data.clone(), 42);
// /// assert!(t.contains_key(data.clone()));
// /// t.remove(data.clone());
// /// assert!(!t.contains_key(data));
// /// ```
// pub fn remove<I: Iterator<Item = K>>(&mut self, key: I) -> Option<V> {
// let mut current = &mut self.root;
// let mut path = Vec::new();

// // Traverse the trie to find the node
// for k in key {
// if let Some(index) = current.children.iter().position(|(ckey, _)| ckey == &k) {
// path.push((current, index));
// current = &mut current.children[index].1;
// } else {
// return None; // Key not found
// }
// }

// // Remove the value from the leaf node
// let value = current.value.take();

// // Remove unnecessary nodes
// while let Some((parent, child_index)) = path.pop() {
// if current.children.is_empty() && current.value.is_none() {
// parent.children.remove(child_index);
// } else {
// break;
// }
// current = parent;
// }

// value
// }

/// Finds the node in the `Trie` for a given key
///
/// Internal API
Expand Down
22 changes: 21 additions & 1 deletion tests/trie_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,32 @@ mod tests {

t.insert(test.clone(), String::from("test"));
t.insert(tes.clone(), String::from("tes"));
for (k, v) in t.iter() {
// TODO: for (k, v) in &t {
for (k, v) in &t {
assert!(std::str::from_utf8(&k).unwrap().starts_with("tes"));
assert!(v.starts_with("tes"));
}
}

#[test]
fn test_remove() {
let mut trie = Trie::new();
trie.insert("hello".bytes(), 1);
trie.insert("hell".bytes(), 2);
trie.insert("h".bytes(), 3);

assert_eq!(trie.remove("hello".bytes()), Some(1));
assert_eq!(trie.get("hello".bytes()), None);
assert_eq!(trie.get("hell".bytes()), Some(&2));
assert_eq!(trie.get("h".bytes()), Some(&3));

assert_eq!(trie.remove("h".bytes()), Some(3));
assert_eq!(trie.get("h".bytes()), None);
assert_eq!(trie.get("hell".bytes()), Some(&2));

assert_eq!(trie.remove("nonexistent".bytes()), None);
}

#[cfg(feature = "serde")]
#[test]
fn serde_serialize() {
Expand Down

0 comments on commit c4707ae

Please sign in to comment.