diff --git a/AK/DisjointChunks.h b/AK/DisjointChunks.h
index 8b6033f945..f95bdb5225 100644
--- a/AK/DisjointChunks.h
+++ b/AK/DisjointChunks.h
@@ -94,6 +94,14 @@ public:
     DisjointSpans& operator=(DisjointSpans&&) = default;
     DisjointSpans& operator=(DisjointSpans const&) = default;
 
+    Span<T> singular_span() const
+    {
+        VERIFY(m_spans.size() == 1);
+        return m_spans[0];
+    }
+
+    SpanContainer const& individual_spans() const { return m_spans; }
+
     bool operator==(DisjointSpans const& other) const
     {
         if (other.size() != size())
@@ -440,8 +448,22 @@ private:
     Vector<ChunkType> m_chunks;
 };
 
-}
+template<typename T>
+struct Traits<DisjointSpans<T>> : public GenericTraits<DisjointSpans<T>> {
+    static unsigned hash(DisjointSpans<T> const& span)
+    {
+        unsigned hash = 0;
+        for (auto const& value : span) {
+            auto value_hash = Traits<T>::hash(value);
+            hash = pair_int_hash(hash, value_hash);
+        }
+        return hash;
+    }
+    constexpr static bool is_trivial() { return false; }
+};
+
+}
 
 #if USING_AK_GLOBALLY
 using AK::DisjointChunks;
 using AK::DisjointSpans;
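
Note on the DisjointSpans changes above: the new Traits specialization hashes the logical sequence element by element, so two DisjointSpans that carve the same underlying data into different span boundaries still hash alike — the property the regex trie below relies on when it keys nodes by instruction bytes. A minimal stand-alone sketch of that property (std::span stands in for AK::Span, and combine() for AK's pair_int_hash; none of the AK internals are reproduced here):

```cpp
#include <cstdint>
#include <cstdio>
#include <span>
#include <vector>

// Stand-in for AK's pair_int_hash(): any order-sensitive combiner works for illustration.
static unsigned combine(unsigned acc, unsigned value)
{
    return acc * 31u + value + 0x9e3779b9u;
}

// Mirrors the element-wise loop in the new Traits<DisjointSpans<T>>::hash().
static unsigned hash_spans(std::vector<std::span<uint64_t const>> const& spans)
{
    unsigned hash = 0;
    for (auto const& span : spans)
        for (auto value : span)
            hash = combine(hash, static_cast<unsigned>(value));
    return hash;
}

int main()
{
    std::vector<uint64_t> data { 1, 2, 3, 4 };
    std::span<uint64_t const> whole { data };
    // {1,2}+{3,4} and {1,2,3,4} describe the same logical sequence,
    // so they must hash identically for lookups to find existing trie nodes.
    unsigned split = hash_spans({ whole.subspan(0, 2), whole.subspan(2, 2) });
    unsigned joined = hash_spans({ whole });
    printf("%u == %u\n", split, joined);
}
```
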
diff --git a/Tests/LibRegex/Regex.cpp b/Tests/LibRegex/Regex.cpp
index d88f9cd6a3..4063984232 100644
--- a/Tests/LibRegex/Regex.cpp
+++ b/Tests/LibRegex/Regex.cpp
@@ -1047,6 +1047,8 @@ TEST_CASE(optimizer_alternation)
     Array tests {
         // Pattern, Subject, Expected length
        Tuple { "a|"sv, "a"sv, 1u },
+        Tuple { "a|a|a|a|a|a|a|a|a|b"sv, "a"sv, 1u },
+        Tuple { "ab|ac|ad|bc"sv, "bc"sv, 2u },
    };
 
     for (auto& test : tests) {
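
The two added rows target the new trie layout directly: "a|a|…|b" collapses nine identical branches into one shared path, and "ab|ac|ad|bc" shares only the leading 'a' across three of four branches, so the unshared "bc" tail must still match end to end. Each row effectively asserts something like the following (a sketch only; it assumes the ECMA262 parser and the RegexResult fields used elsewhere in this test file, and the helper name is made up):

```cpp
#include <LibRegex/Regex.h>

static bool matches_with_length(StringView pattern, StringView subject, size_t expected_length)
{
    Regex<ECMA262> re(pattern); // assuming ECMA262, as in the surrounding tests
    auto result = re.match(subject);
    return result.success && result.matches.first().view.length() == expected_length;
}
```
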
diff --git a/Userland/Libraries/LibRegex/RegexByteCode.h b/Userland/Libraries/LibRegex/RegexByteCode.h
index 38aba05d93..6dc0b8845a 100644
--- a/Userland/Libraries/LibRegex/RegexByteCode.h
+++ b/Userland/Libraries/LibRegex/RegexByteCode.h
@@ -185,6 +185,23 @@ public:
         Base::first_chunk().prepend(forward<T>(value));
     }
 
+    void append(Span<ByteCodeValueType const> value)
+    {
+        if (is_empty())
+            Base::append({});
+        auto& last = Base::last_chunk();
+        last.ensure_capacity(value.size());
+        for (auto v : value)
+            last.unchecked_append(v);
+    }
+
+    void ensure_capacity(size_t capacity)
+    {
+        if (is_empty())
+            Base::append({});
+        Base::last_chunk().ensure_capacity(capacity);
+    }
+
     void last_chunk() const = delete;
     void first_chunk() const = delete;
 
@@ -210,20 +227,11 @@ public:
 
     void insert_bytecode_compare_string(StringView view)
     {
-        ByteCode bytecode;
-
-        bytecode.empend(static_cast<ByteCodeValueType>(OpCodeId::Compare));
-        bytecode.empend(static_cast<u64>(1)); // number of arguments
-
-        ByteCode arguments;
-
-        arguments.empend(static_cast<ByteCodeValueType>(CharacterCompareType::String));
-        arguments.insert_string(view);
-
-        bytecode.empend(arguments.size()); // size of arguments
-        bytecode.extend(move(arguments));
-
-        extend(move(bytecode));
+        empend(static_cast<ByteCodeValueType>(OpCodeId::Compare));
+        empend(static_cast<u64>(1)); // number of arguments
+        empend(2 + view.length()); // size of arguments
+        empend(static_cast<ByteCodeValueType>(CharacterCompareType::String));
+        insert_string(view);
     }
 
     void insert_bytecode_group_capture_left(size_t capture_groups_count)
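
insert_bytecode_compare_string now writes the Compare instruction straight into this ByteCode instead of assembling two temporary buffers and splicing them together. The argument-size value `2 + view.length()` implies the flat layout [compare-type marker][string length][code points…], i.e. insert_string() contributing 1 + length entries — that reading of insert_string is inferred from the size arithmetic, not spelled out in the diff. A self-contained sketch of the resulting encoding (the enum values are placeholders, not LibRegex's real constants):

```cpp
#include <cstdint>
#include <string_view>
#include <vector>

using ByteCodeValue = uint64_t;
enum : ByteCodeValue { OpCompare = 1, CompareTypeString = 2 }; // placeholder values

static std::vector<ByteCodeValue> compile_compare_string(std::u32string_view view)
{
    std::vector<ByteCodeValue> code;
    code.push_back(OpCompare);
    code.push_back(1);                 // number of arguments
    code.push_back(2 + view.length()); // size of arguments: marker + length + characters
    code.push_back(CompareTypeString);
    code.push_back(view.length());     // assumed insert_string() layout: length, then code points
    for (auto code_point : view)
        code.push_back(code_point);
    return code;
}

int main()
{
    // Compare, 1, 5, String, 3, 'a', 'b', 'c' -> 8 values in total.
    return compile_compare_string(U"abc").size() == 8 ? 0 : 1;
}
```
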
diff --git a/Userland/Libraries/LibRegex/RegexOptimizer.cpp b/Userland/Libraries/LibRegex/RegexOptimizer.cpp
index e765902464..b12f3034fd 100644
--- a/Userland/Libraries/LibRegex/RegexOptimizer.cpp
+++ b/Userland/Libraries/LibRegex/RegexOptimizer.cpp
@@ -5,9 +5,12 @@
  */
 
 #include <AK/Debug.h>
+#include <AK/Function.h>
+#include <AK/Queue.h>
 #include <AK/QuickSort.h>
 #include <AK/RedBlackTree.h>
 #include <AK/Stack.h>
+#include <AK/Trie.h>
 #include <LibRegex/Regex.h>
 #include <LibRegex/RegexBytecodeStreamOptimizer.h>
 #include <LibUnicode/CharacterTypes.h>
@@ -815,6 +818,9 @@ void Optimizer::append_alternation(ByteCode& target, ByteCode&& left, ByteCode&&
     append_alternation(target, alternatives);
 }
 
+template<typename K, typename V, typename KTraits>
+using OrderedHashMapForTrie = OrderedHashMap<K, V, KTraits>;
+
 void Optimizer::append_alternation(ByteCode& target, Span<ByteCode> alternatives)
 {
     if (alternatives.size() == 0)
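
A note on OrderedHashMapForTrie above: Trie takes its child-map type as a template-template parameter, which pins down an exact parameter list; AK's OrderedHashMap alias carries additional defaulted parameters, so re-aliasing it to the expected three-parameter shape lets it bind (that rationale is inferred from the alias itself, not stated in the patch). The same move, illustrated with std::map rather than AK types:

```cpp
#include <functional>
#include <map>

// A template-template parameter fixes an exact arity for the map it receives.
template<typename K, typename V, typename KTraits, template<typename, typename, typename> typename MapType>
struct TrieLike {
    MapType<K, V, KTraits> children; // child edges live in the supplied map type
};

// std::map has a fourth (allocator) parameter; aliasing it down to three
// parameters mirrors what OrderedHashMapForTrie does for OrderedHashMap.
template<typename K, typename V, typename KTraits>
using ThreeParamMap = std::map<K, V, KTraits>;

int main()
{
    TrieLike<int, char, std::less<int>, ThreeParamMap> example;
    example.children.emplace(1, 'a');
}
```
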
@@ -846,154 +852,311 @@ void Optimizer::append_alternation(ByteCode& target, Span<ByteCode> alternatives
     };
 #endif
 
-    Vector<Vector<Detail::Block>> basic_blocks;
-    basic_blocks.ensure_capacity(alternatives.size());
+    // First, find incoming jump edges.
+    // We need them for two reasons:
+    // - We need to distinguish between insn-A-jumped-to-by-insn-B and insn-A-jumped-to-by-insn-C (as otherwise we'd break trie invariants)
+    // - We need to know which jumps to patch when we're done
 
-    for (auto& entry : alternatives)
-        basic_blocks.append(Regex<PosixBasicParser>::split_basic_blocks(entry));
-
-    Optional<size_t> left_skip;
-    size_t shared_block_count = basic_blocks.first().size();
-    for (auto& entry : basic_blocks)
-        shared_block_count = min(shared_block_count, entry.size());
+    struct JumpEdge {
+        Span<ByteCodeValueType const> jump_insn;
+    };
+    Vector<HashMap<size_t, Vector<JumpEdge>>> incoming_jump_edges_for_each_alternative;
+    incoming_jump_edges_for_each_alternative.resize(alternatives.size());
 
     MatchState state;
-    for (size_t block_index = 0; block_index < shared_block_count; block_index++) {
-        auto& left_block = basic_blocks.first()[block_index];
-        auto left_end = block_index + 1 == basic_blocks.first().size() ? left_block.end : basic_blocks.first()[block_index + 1].start;
-        auto can_continue = true;
+
+    for (size_t i = 0; i < alternatives.size(); ++i) {
+        auto& alternative = alternatives[i];
+        // Add a jump to the "end" of the block; this is implicit in the bytecode, but we need it to be explicit in the trie.
+        // Jump{offset=0}
+        alternative.append(static_cast<ByteCodeValueType>(OpCodeId::Jump));
+        alternative.append(0);
+
+        auto& incoming_jump_edges = incoming_jump_edges_for_each_alternative[i];
+
+        auto alternative_bytes = alternative.spans<1>().singular_span();
+        for (state.instruction_position = 0; state.instruction_position < alternative.size();) {
+            auto& opcode = alternative.get_opcode(state);
+            auto opcode_bytes = alternative_bytes.slice(state.instruction_position, opcode.size());
+
+            switch (opcode.opcode_id()) {
+            case OpCodeId::Jump:
+                incoming_jump_edges.ensure(static_cast<OpCode_Jump const&>(opcode).offset() + state.instruction_position).append({ opcode_bytes });
+                break;
+            case OpCodeId::JumpNonEmpty:
+                incoming_jump_edges.ensure(static_cast<OpCode_JumpNonEmpty const&>(opcode).offset() + state.instruction_position).append({ opcode_bytes });
+                break;
+            case OpCodeId::ForkJump:
+                incoming_jump_edges.ensure(static_cast<OpCode_ForkJump const&>(opcode).offset() + state.instruction_position).append({ opcode_bytes });
+                break;
+            case OpCodeId::ForkStay:
+                incoming_jump_edges.ensure(static_cast<OpCode_ForkStay const&>(opcode).offset() + state.instruction_position).append({ opcode_bytes });
+                break;
+            case OpCodeId::ForkReplaceJump:
+                incoming_jump_edges.ensure(static_cast<OpCode_ForkReplaceJump const&>(opcode).offset() + state.instruction_position).append({ opcode_bytes });
+                break;
+            case OpCodeId::ForkReplaceStay:
+                incoming_jump_edges.ensure(static_cast<OpCode_ForkReplaceStay const&>(opcode).offset() + state.instruction_position).append({ opcode_bytes });
+                break;
+            case OpCodeId::Repeat:
+                incoming_jump_edges.ensure(static_cast<OpCode_Repeat const&>(opcode).offset() + state.instruction_position).append({ opcode_bytes });
+                break;
+            default:
+                break;
+            }
+            state.instruction_position += opcode.size();
+        }
+    }
+
+    struct QualifiedIP {
+        size_t alternative_index;
+        size_t instruction_position;
+    };
+    using Tree = Trie<DisjointSpans<ByteCodeValueType const>, Vector<QualifiedIP>, Traits<DisjointSpans<ByteCodeValueType const>>, void, OrderedHashMapForTrie>;
+    Tree trie { {} }; // Root node is empty, key{ instruction_bytes, dependent_instruction_bytes... } -> IP
+
+    size_t common_hits = 0;
+    size_t total_nodes = 0;
+    size_t total_bytecode_entries_in_tree = 0;
+    for (size_t i = 0; i < alternatives.size(); ++i) {
+        auto& alternative = alternatives[i];
+        auto& incoming_jump_edges = incoming_jump_edges_for_each_alternative[i];
+
+        auto* active_node = &trie;
+        auto alternative_span = alternative.spans<1>().singular_span();
+        for (state.instruction_position = 0; state.instruction_position < alternative_span.size();) {
+            total_nodes += 1;
+            auto& opcode = alternative.get_opcode(state);
+            auto opcode_bytes = alternative_span.slice(state.instruction_position, opcode.size());
+            Vector<Span<ByteCodeValueType const>> node_key_bytes;
+            node_key_bytes.append(opcode_bytes);
+
+            if (auto edges = incoming_jump_edges.get(state.instruction_position); edges.has_value()) {
+                for (auto& edge : *edges)
+                    node_key_bytes.append(edge.jump_insn);
+            }
+
+            active_node = static_cast<Tree*>(MUST(active_node->ensure_child(DisjointSpans<ByteCodeValueType const> { move(node_key_bytes) })));
+
+            if (active_node->has_metadata()) {
+                active_node->metadata_value().append({ i, state.instruction_position });
+                common_hits += 1;
+            } else {
+                active_node->set_metadata(Vector<QualifiedIP> { QualifiedIP { i, state.instruction_position } });
+                total_bytecode_entries_in_tree += opcode.size();
+            }
+            state.instruction_position += opcode.size();
+        }
+    }
+
+    if constexpr (REGEX_DEBUG) {
+        Function<void(decltype(trie)&, size_t)> print_tree = [&](decltype(trie)& node, size_t indent = 0) mutable {
+            DeprecatedString name = "(no ip)";
+            DeprecatedString insn;
+            if (node.has_metadata()) {
+                name = DeprecatedString::formatted(
+                    "{}@{} ({} node{})",
+                    node.metadata_value().first().instruction_position,
+                    node.metadata_value().first().alternative_index,
+                    node.metadata_value().size(),
+                    node.metadata_value().size() == 1 ? "" : "s");
+
+                MatchState state;
+                state.instruction_position = node.metadata_value().first().instruction_position;
+                auto& opcode = alternatives[node.metadata_value().first().alternative_index].get_opcode(state);
+                insn = DeprecatedString::formatted("{} {}", opcode.to_deprecated_string(), opcode.arguments_string());
+            }
+            dbgln("{:->{}}| {} -- {}", "", indent * 2, name, insn);
+            for (auto& child : node.children())
+                print_tree(static_cast<decltype(trie)&>(*child.value), indent + 1);
+        };
+
+        print_tree(trie, 0);
+    }
+
+    // This is really only worth it if we don't blow up the size by the 2-extra-instruction-per-node scheme; similarly, if no nodes are shared, we're better off not using a tree.
+    auto tree_cost = (total_nodes - common_hits) * 2;
+    auto chain_cost = total_nodes + alternatives.size() * 2;
+    dbgln_if(REGEX_DEBUG, "Total nodes: {}, common hits: {} (tree cost = {}, chain cost = {})", total_nodes, common_hits, tree_cost, chain_cost);
+
+    if (common_hits == 0 || tree_cost > chain_cost) {
+        // It's better to lay these out as a normal sequence of instructions.
+        auto patch_start = target.size();
         for (size_t i = 1; i < alternatives.size(); ++i) {
-            auto& right_blocks = basic_blocks[i];
-            auto& right_block = right_blocks[block_index];
-            auto right_end = block_index + 1 == right_blocks.size() ? right_block.end : right_blocks[block_index + 1].start;
-
-            if (left_end - left_block.start != right_end - right_block.start) {
-                can_continue = false;
-                break;
-            }
-
-            if (alternatives[0].spans().slice(left_block.start, left_end - left_block.start) != alternatives[i].spans().slice(right_block.start, right_end - right_block.start)) {
-                can_continue = false;
-                break;
-            }
-        }
-        if (!can_continue)
-            break;
-
-        size_t i = 0;
-        for (auto& entry : alternatives) {
-            auto& blocks = basic_blocks[i++];
-            auto& block = blocks[block_index];
-            auto end = block_index + 1 == blocks.size() ? block.end : blocks[block_index + 1].start;
-            state.instruction_position = block.start;
-            size_t skip = 0;
-            while (state.instruction_position < end) {
-                auto& opcode = entry.get_opcode(state);
-                state.instruction_position += opcode.size();
-                skip = state.instruction_position;
-            }
-
-            if (left_skip.has_value())
-                left_skip = min(skip, *left_skip);
-            else
-                left_skip = skip;
-        }
-    }
-
-    // Remove forward jumps as they no longer make sense.
-    state.instruction_position = 0;
-    for (size_t i = 0; i < left_skip.value_or(0);) {
-        auto& opcode = alternatives[0].get_opcode(state);
-        switch (opcode.opcode_id()) {
-        case OpCodeId::Jump:
-        case OpCodeId::ForkJump:
-        case OpCodeId::JumpNonEmpty:
-        case OpCodeId::ForkStay:
-        case OpCodeId::ForkReplaceJump:
-        case OpCodeId::ForkReplaceStay:
-            if (opcode.argument(0) + opcode.size() > left_skip.value_or(0)) {
-                left_skip = i;
-                goto break_out;
-            }
-            break;
-        default:
-            break;
-        }
-        i += opcode.size();
-    }
-break_out:;
-
-    dbgln_if(REGEX_DEBUG, "Skipping {}/{} bytecode entries from {}", left_skip, 0, alternatives[0].size());
-
-    if (left_skip.has_value() && *left_skip > 0) {
-        target.extend(alternatives[0].release_slice(basic_blocks.first().first().start, *left_skip));
-        auto first = true;
-        for (auto& entry : alternatives) {
-            if (first) {
-                first = false;
-                continue;
-            }
-            entry = entry.release_slice(*left_skip);
-        }
-    }
-
-    if (all_of(alternatives, [](auto& entry) { return entry.is_empty(); }))
-        return;
-
-    size_t patch_start = target.size();
-    for (size_t i = 1; i < alternatives.size(); ++i) {
-        target.empend(static_cast<ByteCodeValueType>(OpCodeId::ForkJump));
-        target.empend(0u); // To be filled later.
-    }
-
-    size_t size_to_jump = 0;
-    bool seen_one_empty = false;
-    for (size_t i = alternatives.size(); i > 0; --i) {
-        auto& entry = alternatives[i - 1];
-        if (entry.is_empty()) {
-            if (seen_one_empty)
-                continue;
-            seen_one_empty = true;
+            target.empend(static_cast<ByteCodeValueType>(OpCodeId::ForkJump));
+            target.empend(0u); // To be filled later.
         }
 
-        auto is_first = i == 1;
-        auto instruction_size = entry.size() + (is_first ? 0 : 2); // Jump; -> +2
-        size_to_jump += instruction_size;
-
-        if (!is_first)
-            target[patch_start + (i - 2) * 2 + 1] = size_to_jump + (alternatives.size() - i) * 2;
-
-        dbgln_if(REGEX_DEBUG, "{} size = {}, cum={}", i - 1, instruction_size, size_to_jump);
-    }
-
-    seen_one_empty = false;
-    for (size_t i = alternatives.size(); i > 0; --i) {
-        auto& chunk = alternatives[i - 1];
-        if (chunk.is_empty()) {
-            if (seen_one_empty)
-                continue;
-            seen_one_empty = true;
-        }
-
-        ByteCode* previous_chunk = nullptr;
-        size_t j = i - 1;
-        auto seen_one_empty_before = chunk.is_empty();
-        while (j >= 1) {
-            --j;
-            auto& candidate_chunk = alternatives[j];
-            if (candidate_chunk.is_empty()) {
-                if (seen_one_empty_before)
+        size_t size_to_jump = 0;
+        bool seen_one_empty = false;
+        for (size_t i = alternatives.size(); i > 0; --i) {
+            auto& entry = alternatives[i - 1];
+            if (entry.is_empty()) {
+                if (seen_one_empty)
                     continue;
+                seen_one_empty = true;
             }
-            previous_chunk = &candidate_chunk;
-            break;
+
+            auto is_first = i == 1;
+            auto instruction_size = entry.size() + (is_first ? 0 : 2); // Jump; -> +2
+            size_to_jump += instruction_size;
+
+            if (!is_first)
+                target[patch_start + (i - 2) * 2 + 1] = size_to_jump + (alternatives.size() - i) * 2;
+
+            dbgln_if(REGEX_DEBUG, "{} size = {}, cum={}", i - 1, instruction_size, size_to_jump);
         }
 
-        size_to_jump -= chunk.size() + (previous_chunk ? 2 : 0);
+        seen_one_empty = false;
+        for (size_t i = alternatives.size(); i > 0; --i) {
+            auto& chunk = alternatives[i - 1];
+            if (chunk.is_empty()) {
+                if (seen_one_empty)
+                    continue;
+                seen_one_empty = true;
+            }
 
-        target.extend(move(chunk));
-        target.empend(static_cast<ByteCodeValueType>(OpCodeId::Jump));
-        target.empend(size_to_jump); // Jump to the _END label
+            ByteCode* previous_chunk = nullptr;
+            size_t j = i - 1;
+            auto seen_one_empty_before = chunk.is_empty();
+            while (j >= 1) {
+                --j;
+                auto& candidate_chunk = alternatives[j];
+                if (candidate_chunk.is_empty()) {
+                    if (seen_one_empty_before)
+                        continue;
+                }
+                previous_chunk = &candidate_chunk;
+                break;
+            }
+
+            size_to_jump -= chunk.size() + (previous_chunk ? 2 : 0);
+
+            target.extend(move(chunk));
+            target.empend(static_cast<ByteCodeValueType>(OpCodeId::Jump));
+            target.empend(size_to_jump); // Jump to the _END label
+        }
+    } else {
+        target.ensure_capacity(total_bytecode_entries_in_tree + common_hits * 6);
+
+        auto node_is = [](Tree const* node, QualifiedIP ip) {
+            if (!node->has_metadata())
+                return false;
+            for (auto& node_ip : node->metadata_value()) {
+                if (node_ip.alternative_index == ip.alternative_index && node_ip.instruction_position == ip.instruction_position)
+                    return true;
+            }
+            return false;
+        };
+
+        struct Patch {
+            QualifiedIP source_ip;
+            size_t target_ip;
+            bool done { false };
+        };
+        Vector<Patch> patch_locations;
+        patch_locations.ensure_capacity(total_nodes);
+
+        auto add_patch_point = [&](Tree const* node, size_t target_ip) {
+            if (!node->has_metadata())
+                return;
+            auto& node_ip = node->metadata_value().first();
+            patch_locations.append({ node_ip, target_ip });
+        };
+
+        Queue<Tree*> nodes_to_visit;
+        nodes_to_visit.enqueue(&trie);
+
+        // each node:
+        //   node.re
+        //   forkjump child1
+        //   forkjump child2
+        //   ...
+        while (!nodes_to_visit.is_empty()) {
+            auto const* node = nodes_to_visit.dequeue();
+            for (auto& patch : patch_locations) {
+                if (!patch.done && node_is(node, patch.source_ip)) {
+                    auto value = static_cast<ByteCodeValueType>(target.size() - patch.target_ip - 1);
+                    target[patch.target_ip] = value;
+                    patch.done = true;
+                }
+            }
+
+            if (!node->value().individual_spans().is_empty()) {
+                auto insn_bytes = node->value().individual_spans().first();
+
+                target.ensure_capacity(target.size() + insn_bytes.size());
+                state.instruction_position = target.size();
+                target.append(insn_bytes);
+
+                auto& opcode = target.get_opcode(state);
+
+                ssize_t jump_offset;
+                auto is_jump = true;
+                auto patch_location = state.instruction_position + 1;
+
+                switch (opcode.opcode_id()) {
+                case OpCodeId::Jump:
+                    jump_offset = static_cast<OpCode_Jump const&>(opcode).offset();
+                    break;
+                case OpCodeId::JumpNonEmpty:
+                    jump_offset = static_cast<OpCode_JumpNonEmpty const&>(opcode).offset();
+                    break;
+                case OpCodeId::ForkJump:
+                    jump_offset = static_cast<OpCode_ForkJump const&>(opcode).offset();
+                    break;
+                case OpCodeId::ForkStay:
+                    jump_offset = static_cast<OpCode_ForkStay const&>(opcode).offset();
+                    break;
+                case OpCodeId::ForkReplaceJump:
+                    jump_offset = static_cast<OpCode_ForkReplaceJump const&>(opcode).offset();
+                    break;
+                case OpCodeId::ForkReplaceStay:
+                    jump_offset = static_cast<OpCode_ForkReplaceStay const&>(opcode).offset();
+                    break;
+                case OpCodeId::Repeat:
+                    jump_offset = static_cast<ssize_t>(0) - static_cast<ssize_t>(static_cast<OpCode_Repeat const&>(opcode).offset());
+                    break;
+                default:
+                    is_jump = false;
+                    break;
+                }
+
+                if (is_jump) {
+                    VERIFY(node->has_metadata());
+                    auto& ip = node->metadata_value().first();
+                    patch_locations.append({ QualifiedIP { ip.alternative_index, ip.instruction_position + jump_offset + opcode.size() }, patch_location });
+                }
+            }
+
+            for (auto const& child : node->children()) {
+                auto* child_node = static_cast<Tree*>(child.value.ptr());
+                target.append(static_cast<ByteCodeValueType>(OpCodeId::ForkJump));
+                add_patch_point(child_node, target.size());
+                target.append(static_cast<ByteCodeValueType>(0));
+                nodes_to_visit.enqueue(child_node);
+            }
+        }
+
+        for (auto& patch : patch_locations) {
+            if (patch.done)
+                continue;
+
+            auto& alternative = alternatives[patch.source_ip.alternative_index];
+            if (patch.source_ip.instruction_position >= alternative.size()) {
+                // This just wants to jump to the end of the alternative, which is fine.
+                // Patch it to jump to the end of the target instead.
+                target[patch.target_ip] = static_cast<ByteCodeValueType>(target.size() - patch.target_ip - 1);
+                continue;
+            }
+
+            dbgln("Regex Tree / Unpatched jump: {}@{} -> {}@{}",
+                patch.source_ip.instruction_position,
+                patch.source_ip.alternative_index,
+                patch.target_ip,
+                target[patch.target_ip]);
+            VERIFY_NOT_REACHED();
+        }
+    }
 }
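
To make the size heuristic above concrete: every trie node that isn't shared may cost an extra ForkJump-plus-offset pair (2 bytecode values) to reach its siblings, while the plain chain layout pays roughly one 2-value Jump per alternative. Plugging made-up counts into the same formulas (the real numbers come from the trie walk in the patch):

```cpp
#include <cstddef>
#include <cstdio>

int main()
{
    size_t total_nodes = 20;  // bytecode instructions across all alternatives
    size_t common_hits = 6;   // instructions that landed on an existing trie node
    size_t alternative_count = 4;

    size_t tree_cost = (total_nodes - common_hits) * 2;
    size_t chain_cost = total_nodes + alternative_count * 2;

    // Mirrors `common_hits == 0 || tree_cost > chain_cost` in the patch:
    // with no sharing (or a bloated tree), the plain chain layout wins.
    bool use_chain = common_hits == 0 || tree_cost > chain_cost;
    printf("tree=%zu chain=%zu -> %s\n", tree_cost, chain_cost, use_chain ? "chain" : "trie");
}
```
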