diff --git a/AK/RedBlackTree.h b/AK/RedBlackTree.h index acdc2b9ed4..d6b4638050 100644 --- a/AK/RedBlackTree.h +++ b/AK/RedBlackTree.h @@ -133,16 +133,19 @@ protected: static Node* find_smallest_not_below(Node* node, K key) { + Node* candidate = nullptr; while (node) { - if (node->key >= key && (!node->left_child || node->left_child->key < key)) + if (node->key == key) return node; - if (node->key <= key) + if (node->key <= key) { node = node->right_child; - else + } else { + candidate = node; node = node->left_child; + } } - return node; + return candidate; } void insert(Node* node) diff --git a/Tests/AK/TestRedBlackTree.cpp b/Tests/AK/TestRedBlackTree.cpp index 51f9549957..f6d1ddebac 100644 --- a/Tests/AK/TestRedBlackTree.cpp +++ b/Tests/AK/TestRedBlackTree.cpp @@ -86,3 +86,45 @@ TEST_CASE(clear) test.clear(); EXPECT_EQ(test.size(), 0u); } + +TEST_CASE(find_smallest_not_below_iterator) +{ + RedBlackTree test; + + for (size_t i = 0; i < 8; i++) { + auto above_all = test.find_smallest_not_below_iterator(i); + EXPECT(above_all.is_end()); + + test.insert(i, i); + + auto only_just_added_i_is_not_below = test.find_smallest_not_below_iterator(i); + EXPECT(!only_just_added_i_is_not_below.is_end()); + EXPECT_EQ(only_just_added_i_is_not_below.key(), i); + } + + { + auto smallest_not_below_two = test.find_smallest_not_below_iterator(2); + EXPECT(!smallest_not_below_two.is_end()); + EXPECT_EQ(smallest_not_below_two.key(), 2u); + } + + test.remove(2); + + { + auto smallest_not_below_two_without_two = test.find_smallest_not_below_iterator(2); + EXPECT(!smallest_not_below_two_without_two.is_end()); + EXPECT_EQ(smallest_not_below_two_without_two.key(), 3u); + } + + { + auto smallest_not_below_one = test.find_smallest_not_below_iterator(1); + EXPECT(!smallest_not_below_one.is_end()); + EXPECT_EQ(smallest_not_below_one.key(), 1u); + } + + { + auto smallest_not_below_three = test.find_smallest_not_below_iterator(3); + EXPECT(!smallest_not_below_three.is_end()); + EXPECT_EQ(smallest_not_below_three.key(), 3u); + } +}