diff --git a/src/bst.rs b/src/bst.rs index 33fc57a..a8ee6fe 100644 --- a/src/bst.rs +++ b/src/bst.rs @@ -33,8 +33,8 @@ impl BinarySearchTree { } } - pub fn find(&mut self, data: T) -> Option { - self.root.take().map(|mut root| root.find(data))? + pub fn find(&self, data: T) -> Option { + self.root.as_ref().map(|root| root.find(data))? } } @@ -66,11 +66,43 @@ impl Node { } } - fn find(&mut self, data: T) -> Option { + fn find(&self, data: T) -> Option { match self.data.cmp(&data) { - Less => self.right.take().map(|mut right| right.find(data))?, + Less => self.right.as_ref().map(|right| right.find(data))?, Equal => Some(data), - Greater => self.left.take().map(|mut left| left.find(data))?, + Greater => self.left.as_ref().map(|left| left.find(data))?, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn insert_and_find() { + let mut tree = BinarySearchTree::new(); + tree.insert(1); + assert_eq!(tree.find(1), Some(1)); + } + + #[test] + fn insert_three_values_find_third() { + let mut tree = BinarySearchTree::new(); + for num in [1, 2, 3] { + tree.insert(num); + } + assert_eq!(tree.find(3), Some(3)); + } + + #[test] + fn multiple_finds() { + let mut tree = BinarySearchTree::new(); + for num in [1, 2, 3] { + tree.insert(num); + } + for num in [1, 2, 3] { + assert_eq!(tree.find(num), Some(num)); } } }