diff --git a/AK/HashMap.h b/AK/HashMap.h index 89e6586648..f707c66068 100644 --- a/AK/HashMap.h +++ b/AK/HashMap.h @@ -57,6 +57,9 @@ public: HashSetResult set(const K& key, const V& value) { return m_table.set({ key, value }); } HashSetResult set(const K& key, V&& value) { return m_table.set({ key, move(value) }); } + HashSetResult try_set(const K& key, const V& value) { return m_table.try_set({ key, value }); } + HashSetResult try_set(const K& key, V&& value) { return m_table.try_set({ key, move(value) }); } + bool remove(const K& key) { auto it = find(key); @@ -96,6 +99,7 @@ public: } void ensure_capacity(size_t capacity) { m_table.ensure_capacity(capacity); } + bool try_ensure_capacity(size_t capacity) { return m_table.try_ensure_capacity(capacity); } Optional::PeekType> get(const K& key) const requires(!IsPointer::PeekType>) { diff --git a/AK/HashTable.h b/AK/HashTable.h index dd40c6906d..beebda4a76 100644 --- a/AK/HashTable.h +++ b/AK/HashTable.h @@ -15,6 +15,7 @@ namespace AK { enum class HashSetResult { + Failed = 0, InsertedNewEntry, ReplacedExistingEntry, KeptExistingEntry @@ -184,11 +185,19 @@ public: [[nodiscard]] size_t capacity() const { return m_capacity; } template - void set_from(U (&from_array)[N]) + bool try_set_from(U (&from_array)[N]) { for (size_t i = 0; i < N; ++i) { - set(from_array[i]); + if (try_set(from_array[i]) == HashSetResult::Failed) + return false; } + return true; + } + template + void set_from(U (&from_array)[N]) + { + bool result = try_set_from(from_array); + VERIFY(result); } void ensure_capacity(size_t capacity) @@ -250,36 +259,45 @@ public: } template - HashSetResult set(U&& value, HashSetExistingEntryBehavior existing_entry_behavior = HashSetExistingEntryBehavior::Replace) + HashSetResult try_set(U&& value, HashSetExistingEntryBehavior existing_entry_behavior = HashSetExistingEntryBehavior::Replace) { - auto& bucket = lookup_for_writing(value); - if (bucket.used) { + auto* bucket = try_lookup_for_writing(value); + if (!bucket) + return HashSetResult::Failed; + if (bucket->used) { if (existing_entry_behavior == HashSetExistingEntryBehavior::Keep) return HashSetResult::KeptExistingEntry; - (*bucket.slot()) = forward(value); + (*bucket->slot()) = forward(value); return HashSetResult::ReplacedExistingEntry; } - new (bucket.slot()) T(forward(value)); - bucket.used = true; - if (bucket.deleted) { - bucket.deleted = false; + new (bucket->slot()) T(forward(value)); + bucket->used = true; + if (bucket->deleted) { + bucket->deleted = false; --m_deleted_count; } if constexpr (IsOrdered) { if (!m_collection_data.head) [[unlikely]] { - m_collection_data.head = &bucket; + m_collection_data.head = bucket; } else { - bucket.previous = m_collection_data.tail; - m_collection_data.tail->next = &bucket; + bucket->previous = m_collection_data.tail; + m_collection_data.tail->next = bucket; } - m_collection_data.tail = &bucket; + m_collection_data.tail = bucket; } ++m_size; return HashSetResult::InsertedNewEntry; } + template + HashSetResult set(U&& value, HashSetExistingEntryBehavior existing_entry_behaviour = HashSetExistingEntryBehavior::Replace) + { + auto result = try_set(forward(value), existing_entry_behaviour); + VERIFY(result != HashSetResult::Failed); + return result; + } template [[nodiscard]] Iterator find(unsigned hash, TUnaryPredicate predicate) @@ -369,7 +387,7 @@ private: } } - void rehash(size_t new_capacity) + bool try_rehash(size_t new_capacity) { new_capacity = max(new_capacity, static_cast(4)); new_capacity = kmalloc_good_size(new_capacity * sizeof(BucketType)) / sizeof(BucketType); @@ -378,24 +396,23 @@ private: auto old_capacity = m_capacity; Iterator old_iter = begin(); - if constexpr (IsOrdered) { - m_buckets = (BucketType*)kmalloc(size_in_bytes(new_capacity)); - __builtin_memset(m_buckets, 0, size_in_bytes(new_capacity)); + auto new_buckets = kmalloc(size_in_bytes(new_capacity)); + if (!new_buckets) + return false; - m_collection_data = { nullptr, nullptr }; - } else { - m_buckets = (BucketType*)kmalloc(size_in_bytes(new_capacity)); - __builtin_memset(m_buckets, 0, size_in_bytes(new_capacity)); - } + m_buckets = (BucketType*)new_buckets; + __builtin_memset(m_buckets, 0, size_in_bytes(new_capacity)); m_capacity = new_capacity; m_deleted_count = 0; - if constexpr (!IsOrdered) + if constexpr (IsOrdered) + m_collection_data = { nullptr, nullptr }; + else m_buckets[m_capacity].end = true; if (!old_buckets) - return; + return true; for (auto it = move(old_iter); it != end(); ++it) { insert_during_rehash(move(*it)); @@ -403,6 +420,12 @@ private: } kfree_sized(old_buckets, size_in_bytes(old_capacity)); + return true; + } + void rehash(size_t new_capacity) + { + bool result = try_rehash(new_capacity); + VERIFY(result); } template @@ -424,30 +447,40 @@ private: } } - [[nodiscard]] BucketType& lookup_for_writing(T const& value) + [[nodiscard]] BucketType* try_lookup_for_writing(T const& value) { - if (should_grow()) - rehash(capacity() * 2); - + // FIXME: Maybe overrun the "allowed" load factor to avoid OOM + // If we are allowed to do that, separate that logic from + // the normal lookup_for_writing + if (should_grow()) { + if (!try_rehash(capacity() * 2)) + return nullptr; + } auto hash = TraitsForT::hash(value); BucketType* first_empty_bucket = nullptr; for (;;) { auto& bucket = m_buckets[hash % m_capacity]; if (bucket.used && TraitsForT::equals(*bucket.slot(), value)) - return bucket; + return &bucket; if (!bucket.used) { if (!first_empty_bucket) first_empty_bucket = &bucket; if (!bucket.deleted) - return *const_cast(first_empty_bucket); + return const_cast(first_empty_bucket); } hash = double_hash(hash); } } + [[nodiscard]] BucketType& lookup_for_writing(T const& value) + { + auto* item = try_lookup_for_writing(value); + VERIFY(item); + return *item; + } [[nodiscard]] size_t used_bucket_count() const { return m_size + m_deleted_count; } [[nodiscard]] bool should_grow() const { return ((used_bucket_count() + 1) * 100) >= (m_capacity * load_factor_in_percent); }